mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Compare commits
26 Commits
f6e7c200c0
...
remove-bes
Author | SHA1 | Date | |
---|---|---|---|
56180a5e26 | |||
1684ef279a | |||
aab21eb5e7 | |||
b997a31981 | |||
86d1963cc1 | |||
039d526d24 | |||
bcd059a384 | |||
0e57b4a9df | |||
98488e0946 | |||
f45e86571b | |||
f5827928a0 | |||
f853e091ea | |||
803ec0d856 | |||
7a0a615d50 | |||
c38cb69ec7 | |||
68ef15c686 | |||
3dd7fc2850 | |||
51ced65153 | |||
4bb883a6e6 | |||
f7846321e7 | |||
a944890ff1 | |||
521db3520a | |||
e2c97a805a | |||
d1d0407d3c | |||
824ff8c73e | |||
f15399d3d3 |
4
.github/workflows/slow-tests.yml
vendored
4
.github/workflows/slow-tests.yml
vendored
@ -68,7 +68,7 @@ jobs:
|
||||
CUDA_VISIBLE_DEVICES: "0,1"
|
||||
TEST_TYPE: "multi_gpu"
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all --shm-size "16gb"
|
||||
defaults:
|
||||
run:
|
||||
@ -115,6 +115,4 @@ jobs:
|
||||
source .venv/bin/activate
|
||||
uv pip install slack_sdk tabulate
|
||||
python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
|
||||
python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY
|
||||
python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY
|
||||
rm *.txt
|
13
.github/workflows/tests.yml
vendored
13
.github/workflows/tests.yml
vendored
@ -11,11 +11,12 @@ on:
|
||||
- "scripts/**.py"
|
||||
- "tests/**.py"
|
||||
- "trl/**.py"
|
||||
- "setup.py"
|
||||
- "pyproject.toml"
|
||||
|
||||
env:
|
||||
TQDM_DISABLE: 1
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
|
||||
PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
|
||||
|
||||
jobs:
|
||||
check_code_quality:
|
||||
@ -41,7 +42,7 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
@ -93,7 +94,7 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
@ -128,7 +129,7 @@ jobs:
|
||||
uv pip install -U git+https://github.com/huggingface/accelerate.git
|
||||
uv pip install -U git+https://github.com/huggingface/datasets.git
|
||||
uv pip install -U git+https://github.com/huggingface/transformers.git
|
||||
|
||||
uv pip install -U git+https://github.com/huggingface/peft.git
|
||||
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
@ -149,7 +150,7 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
@ -201,7 +202,7 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
|
2
.github/workflows/tests_latest.yml
vendored
2
.github/workflows/tests_latest.yml
vendored
@ -16,7 +16,7 @@ jobs:
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel
|
||||
image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
|
||||
options: --gpus all
|
||||
defaults:
|
||||
run:
|
||||
|
182
CONTRIBUTING.md
182
CONTRIBUTING.md
@ -1,15 +1,10 @@
|
||||
# How to contribute to TRL?
|
||||
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code
|
||||
contributions are not the only way to help the community. Answering questions, helping
|
||||
others, and improving the documentation are also immensely valuable.
|
||||
Everyone is welcome to contribute, and we value everybody's contribution. Code contributions are not the only way to help the community. Answering questions, helping others, and improving the documentation are also immensely valuable.
|
||||
|
||||
It also helps us if you spread the word! Reference the library in blog posts
|
||||
about the awesome projects it made possible, shout out on Twitter every time it has
|
||||
helped you, or simply ⭐️ the repository to say thank you.
|
||||
It also helps us if you spread the word! Reference the library in blog posts about the awesome projects it made possible, shout out on Twitter every time it has helped you, or simply ⭐️ the repository to say thank you.
|
||||
|
||||
However you choose to contribute, please be mindful and respect our
|
||||
[code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
|
||||
However you choose to contribute, please be mindful and respect our [code of conduct](https://github.com/huggingface/trl/blob/main/CODE_OF_CONDUCT.md).
|
||||
|
||||
**This guide was heavily inspired by the awesome [scikit-learn guide to contributing](https://github.com/scikit-learn/scikit-learn/blob/main/CONTRIBUTING.md).**
|
||||
|
||||
@ -22,9 +17,7 @@ There are several ways you can contribute to TRL:
|
||||
* Implement trainers for new post-training algorithms.
|
||||
* Contribute to the examples or the documentation.
|
||||
|
||||
If you don't know where to start, there is a special [Good First
|
||||
Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of
|
||||
open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
|
||||
If you don't know where to start, there is a special [Good First Issue](https://github.com/huggingface/trl/labels/%F0%9F%91%B6%20good%20first%20issue) listing. It will give you a list of open issues that are beginner-friendly and help you start contributing to open-source. The best way to do that is to open a Pull Request and link it to the issue that you'd like to work on. We try to give priority to opened PRs as we can easily track the progress of the fix, and if the contributor does not have time anymore, someone else can take the PR over.
|
||||
|
||||
For something slightly more challenging, you can also take a look at the [Good Second Issue](https://github.com/huggingface/trl/labels/Good%20Second%20Issue) list. In general though, if you feel like you know what you're doing, go for it and we'll help you get there! 🚀
|
||||
|
||||
@ -48,14 +41,12 @@ Do your best to follow these guidelines when submitting a bug-related issue or a
|
||||
|
||||
The TRL library is robust and reliable thanks to users who report the problems they encounter.
|
||||
|
||||
Before you report an issue, we would really appreciate it if you could **make sure the bug was not
|
||||
already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
|
||||
Before you report an issue, we would really appreciate it if you could **make sure the bug was not already reported** (use the search bar on GitHub under Issues). Your issue should also be related to bugs in the library itself, and not your code.
|
||||
|
||||
Once you've confirmed the bug hasn't already been reported, please include the following information in your issue so we can quickly resolve it:
|
||||
|
||||
* Your **OS type and version**, **Python**, **PyTorch**, **TRL** and **Transformers** versions.
|
||||
* A short, self-contained, code snippet that allows us to reproduce the bug in
|
||||
less than 30s.
|
||||
* A short, self-contained, code snippet that allows us to reproduce the bug in less than 30s.
|
||||
* The *full* traceback if an exception is raised.
|
||||
* Attach any other additional information, like screenshots, you think may help.
|
||||
|
||||
@ -106,29 +97,20 @@ We're always looking for improvements to the documentation that make it more cle
|
||||
|
||||
## Submitting a pull request (PR)
|
||||
|
||||
Before writing code, we strongly advise you to search through the existing PRs or
|
||||
issues to make sure that nobody is already working on the same thing. If you are
|
||||
unsure, it is always a good idea to open an issue to get some feedback.
|
||||
Before writing code, we strongly advise you to search through the existing PRs or issues to make sure that nobody is already working on the same thing. If you are unsure, it is always a good idea to open an issue to get some feedback.
|
||||
|
||||
You will need basic `git` proficiency to be able to contribute to
|
||||
TRL. `git` is not the easiest tool to use but it has the greatest
|
||||
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
|
||||
Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
You will need basic `git` proficiency to be able to contribute to TRL. `git` is not the easiest tool to use but it has the greatest manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro Git](https://git-scm.com/book/en/v2) is a very good reference.
|
||||
|
||||
Follow these steps to start contributing:
|
||||
|
||||
1. Fork the [repository](https://github.com/huggingface/trl) by
|
||||
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
|
||||
under your GitHub user account.
|
||||
1. Fork the [repository](https://github.com/huggingface/trl) by clicking on the 'Fork' button on the repository's page. This creates a copy of the code under your GitHub user account.
|
||||
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command
|
||||
assumes you have your public SSH key uploaded to GitHub. See the following guide for more
|
||||
[information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
||||
2. Clone your fork to your local disk, and add the base repository as a remote. The following command assumes you have your public SSH key uploaded to GitHub. See the following guide for more [information](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository).
|
||||
|
||||
```bash
|
||||
$ git clone git@github.com:<your Github handle>/trl.git
|
||||
$ cd trl
|
||||
$ git remote add upstream https://github.com/huggingface/trl.git
|
||||
git clone git@github.com:<your Github handle>/trl.git
|
||||
cd trl
|
||||
git remote add upstream https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
3. Create a new branch to hold your development changes, and do this for every new PR you work on.
|
||||
@ -136,15 +118,15 @@ Follow these steps to start contributing:
|
||||
Start by synchronizing your `main` branch with the `upstream/main` branch (more details in the [GitHub Docs](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/syncing-a-fork)):
|
||||
|
||||
```bash
|
||||
$ git checkout main
|
||||
$ git fetch upstream
|
||||
$ git merge upstream/main
|
||||
git checkout main
|
||||
git fetch upstream
|
||||
git merge upstream/main
|
||||
```
|
||||
|
||||
Once your `main` branch is synchronized, create a new branch from it:
|
||||
|
||||
```bash
|
||||
$ git checkout -b a-descriptive-name-for-my-changes
|
||||
git checkout -b a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
**Do not** work on the `main` branch.
|
||||
@ -152,32 +134,28 @@ Follow these steps to start contributing:
|
||||
4. Set up a development environment by running the following command in a conda or a virtual environment you've created for working on this library:
|
||||
|
||||
```bash
|
||||
$ pip install -e .[dev]
|
||||
pip install -e .[dev]
|
||||
```
|
||||
|
||||
(If TRL was already installed in the virtual environment, remove
|
||||
it with `pip uninstall trl` before reinstalling it.)
|
||||
(If TRL was already installed in the virtual environment, remove it with `pip uninstall trl` before reinstalling it.)
|
||||
|
||||
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using
|
||||
the provided Dev Container. Documentation on how to get started with dev containers is available [here](https://code.visualstudio.com/docs/remote/containers).
|
||||
Alternatively, if you are using [Visual Studio Code](https://code.visualstudio.com/Download), the fastest way to get set up is by using the provided Dev Container. Check [the documentation on how to get started with dev containers](https://code.visualstudio.com/docs/remote/containers).
|
||||
|
||||
5. Develop the features on your branch.
|
||||
|
||||
As you work on the features, you should make sure that the test suite
|
||||
passes. You should run the tests impacted by your changes like this (see
|
||||
below an explanation regarding the environment variable):
|
||||
As you work on the features, you should make sure that the test suite passes. You should run the tests impacted by your changes like this (see below an explanation regarding the environment variable):
|
||||
|
||||
```bash
|
||||
$ pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
> For the following commands leveraging the `make` utility.
|
||||
```bash
|
||||
pytest tests/<TEST_TO_RUN>.py
|
||||
```
|
||||
|
||||
You can also run the full suite with the following command.
|
||||
> For the following commands leveraging the `make` utility.
|
||||
|
||||
```bash
|
||||
$ make test
|
||||
```
|
||||
You can also run the full suite with the following command.
|
||||
|
||||
```bash
|
||||
make test
|
||||
```
|
||||
|
||||
TRL relies on `ruff` for maintaining consistent code formatting across its source files. Before submitting any PR, you should apply automatic style corrections and run code verification checks.
|
||||
|
||||
@ -186,59 +164,51 @@ Follow these steps to start contributing:
|
||||
To apply these checks and corrections in one step, use:
|
||||
|
||||
```bash
|
||||
$ make precommit
|
||||
make precommit
|
||||
```
|
||||
|
||||
This command runs the following:
|
||||
- Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
|
||||
- Runs additional scripts such as adding copyright information.
|
||||
|
||||
* Executes `pre-commit` hooks to automatically fix style issues with `ruff` and other tools.
|
||||
* Runs additional scripts such as adding copyright information.
|
||||
|
||||
If you prefer to apply the style corrections separately or review them individually, the `pre-commit` hook will handle the formatting for the files in question.
|
||||
|
||||
Once you're happy with your changes, add changed files using `git add` and
|
||||
make a commit with `git commit` to record your changes locally:
|
||||
Once you're happy with your changes, add changed files using `git add` and make a commit with `git commit` to record your changes locally:
|
||||
|
||||
```bash
|
||||
$ git add modified_file.py
|
||||
$ git commit
|
||||
```
|
||||
```bash
|
||||
git add modified_file.py
|
||||
git commit
|
||||
```
|
||||
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
Please write [good commit messages](https://chris.beams.io/posts/git-commit/).
|
||||
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
It is a good idea to sync your copy of the code with the original
|
||||
repository regularly. This way you can quickly account for changes:
|
||||
|
||||
```bash
|
||||
$ git fetch upstream
|
||||
$ git rebase upstream/main
|
||||
```
|
||||
```bash
|
||||
git fetch upstream
|
||||
git rebase upstream/main
|
||||
```
|
||||
|
||||
Push the changes to your account using:
|
||||
Push the changes to your account using:
|
||||
|
||||
```bash
|
||||
$ git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
```bash
|
||||
git push -u origin a-descriptive-name-for-my-changes
|
||||
```
|
||||
|
||||
6. Once you are satisfied (**and the checklist below is happy too**), go to the
|
||||
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
|
||||
to the project maintainers for review.
|
||||
6. Once you are satisfied (**and the checklist below is happy too**), go to the webpage of your fork on GitHub. Click on 'Pull request' to send your changes to the project maintainers for review.
|
||||
|
||||
7. It's ok if maintainers ask you for changes. It happens to core contributors too! To ensure everyone can review your changes in the pull request, work on your local branch and push the updates to your fork. They will automatically appear in the pull request.
|
||||
|
||||
|
||||
### Checklist
|
||||
|
||||
1. The title of your pull request should be a summary of its contribution;
|
||||
2. If your pull request addresses an issue, please mention the issue number in
|
||||
the pull request description to make sure they are linked (and people
|
||||
consulting the issue know you are working on it);
|
||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark
|
||||
the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate
|
||||
it from PRs ready to be merged;
|
||||
2. If your pull request addresses an issue, please mention the issue number in the pull request description to make sure they are linked (and people consulting the issue know you are working on it);
|
||||
3. To indicate a work in progress please prefix the title with `[WIP]`, or mark the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate it from PRs ready to be merged;
|
||||
4. Make sure existing tests pass;
|
||||
5. Add high-coverage tests. No quality testing = no merge.
|
||||
|
||||
|
||||
### Tests
|
||||
|
||||
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
|
||||
@ -248,7 +218,7 @@ We use `pytest` to run the tests. From the root of the
|
||||
repository here's how to run tests with `pytest` for the library:
|
||||
|
||||
```bash
|
||||
$ python -m pytest -sv ./tests
|
||||
python -m pytest -sv ./tests
|
||||
```
|
||||
|
||||
That's how `make test` is implemented (without the `pip install` line)!
|
||||
@ -260,23 +230,23 @@ you're working on.
|
||||
|
||||
1. **Use defaults when appropriate**:
|
||||
|
||||
Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
|
||||
Provide default values unless the parameter's value varies significantly by use case. For example, datasets or models should not have defaults, but parameters like `learning_rate` should.
|
||||
|
||||
2. **Prioritize proven defaults**:
|
||||
|
||||
Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
|
||||
Default values should align with those recommended in the original paper or method. Alternatives require strong evidence of superior performance in most cases.
|
||||
|
||||
3. **Ensure safety and predictability**:
|
||||
|
||||
Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
|
||||
Defaults must be safe, expected and reliable. Avoid settings that could lead to surprising outcomes, such as excessive memory usage or poor performance in edge cases.
|
||||
|
||||
4. **Balance consistency and flexibility**:
|
||||
|
||||
Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3.
|
||||
Aim for consistent defaults across similar functions or methods. However, consistency should not be preferred to point 2 or 3.
|
||||
|
||||
5. **Opt-in for new features**:
|
||||
|
||||
Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
|
||||
Do not enable new features or improvements (e.g., novel loss functions) by default. Users should explicitly opt-in to use these.
|
||||
|
||||
### Writing documentation
|
||||
|
||||
@ -318,26 +288,26 @@ def replicate_str(string: str, n: int, sep: str = " ") -> str:
|
||||
* Note that `Optional` means that the value can be `None`, and `*optional*` means that it is not required for the user to pass a value.
|
||||
E.g., for arguments that can't be `None` and aren't required:
|
||||
|
||||
```python
|
||||
```txt
|
||||
foo (`int`, *optional*, defaults to `4`):
|
||||
```
|
||||
|
||||
For arguments that can be `None` and are required:
|
||||
|
||||
```python
|
||||
```txt
|
||||
foo (`Optional[int]`):
|
||||
```
|
||||
|
||||
for arguments that can be `None` and aren't required:
|
||||
for arguments that can be `None` and aren't required (in this case, if the default value is `None`, you can omit it):
|
||||
|
||||
```python
|
||||
```txt
|
||||
foo (`Optional[int]`, *optional*):
|
||||
```
|
||||
|
||||
* **String Defaults:**
|
||||
* Ensured that default string values are wrapped in double quotes:
|
||||
|
||||
```python
|
||||
```txt
|
||||
defaults to `"foo"`
|
||||
```
|
||||
|
||||
@ -346,7 +316,7 @@ def replicate_str(string: str, n: int, sep: str = " ") -> str:
|
||||
* **Default Value Formatting:**
|
||||
* Consistently surrounded default values with backticks for improved formatting:
|
||||
|
||||
```python
|
||||
```txt
|
||||
defaults to `4`
|
||||
```
|
||||
|
||||
@ -383,8 +353,8 @@ Our approach to deprecation and backward compatibility is flexible and based on
|
||||
|
||||
When a feature or component is marked for deprecation, its use will emit a warning message. This warning will include:
|
||||
|
||||
- **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
|
||||
- **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
|
||||
* **Transition Guidance**: Instructions on how to migrate to the alternative solution or replacement.
|
||||
* **Removal Version**: The target version when the feature will be removed, providing users with a clear timeframe to transition.
|
||||
|
||||
Example:
|
||||
|
||||
@ -398,9 +368,9 @@ Example:
|
||||
|
||||
The deprecation and removal schedule is based on each feature's usage and impact, with examples at two extremes:
|
||||
|
||||
- **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
|
||||
* **Experimental or Low-Use Features**: For a feature that is experimental or has limited usage, backward compatibility may not be maintained between releases. Users should therefore anticipate potential breaking changes from one version to the next.
|
||||
|
||||
- **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.
|
||||
* **Widely-Used Components**: For a feature with high usage, we aim for a more gradual transition period of approximately **5 months**, generally scheduling deprecation around **5 minor releases** after the initial warning.
|
||||
|
||||
These examples represent the two ends of a continuum. The specific timeline for each feature will be determined individually, balancing innovation with user stability needs.
|
||||
|
||||
@ -410,22 +380,22 @@ Warnings play a critical role in guiding users toward resolving potential issues
|
||||
|
||||
#### Definitions
|
||||
|
||||
- **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
|
||||
- **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
|
||||
* **Correct**: An operation is correct if it is valid, follows the intended approach, and aligns with the current best practices or guidelines within the codebase. This is the recommended or intended way to perform the operation.
|
||||
* **Supported**: An operation is supported if it is technically valid and works within the current codebase, but it may not be the most efficient, optimal, or recommended way to perform the task. This includes deprecated features or legacy approaches that still work but may be phased out in the future.
|
||||
|
||||
#### Choosing the right message
|
||||
|
||||
- **Correct → No warning**:
|
||||
* **Correct → No warning**:
|
||||
If the operation is fully valid and expected, no message should be issued. The system is working as intended, so no warning is necessary.
|
||||
|
||||
- **Correct but deserves attention → No warning, possibly a log message**:
|
||||
* **Correct but deserves attention → No warning, possibly a log message**:
|
||||
When an operation is correct but uncommon or requires special attention, providing an informational message can be helpful. This keeps users informed without implying any issue. If available, use the logger to output this message. Example:
|
||||
|
||||
```python
|
||||
logger.info("This is an informational message about a rare but correct operation.")
|
||||
```
|
||||
|
||||
- **Correct but very likely a mistake → Warning with option to disable**:
|
||||
* **Correct but very likely a mistake → Warning with option to disable**:
|
||||
In rare cases, you may want to issue a warning for a correct operation that’s very likely a mistake. In such cases, you must provide an option to suppress the warning. This can be done with a flag in the function. Example:
|
||||
|
||||
```python
|
||||
@ -436,7 +406,7 @@ Warnings play a critical role in guiding users toward resolving potential issues
|
||||
# Do something
|
||||
```
|
||||
|
||||
- **Supported but not correct → Warning**:
|
||||
* **Supported but not correct → Warning**:
|
||||
If the operation is technically supported but is deprecated, suboptimal, or could cause future issues (e.g., conflicting arguments), a warning should be raised. This message should be actionable, meaning it must explain how to resolve the issue. Example:
|
||||
|
||||
```python
|
||||
@ -446,7 +416,7 @@ Warnings play a critical role in guiding users toward resolving potential issues
|
||||
# Do something
|
||||
```
|
||||
|
||||
- **Not supported → Exception**:
|
||||
* **Not supported → Exception**:
|
||||
If the operation is invalid or unsupported, raise an exception. This indicates that the operation cannot be performed and requires immediate attention. Example:
|
||||
|
||||
```python
|
||||
|
@ -61,8 +61,6 @@
|
||||
title: Sentiment Tuning
|
||||
- local: using_llama_models
|
||||
title: Training StackLlama
|
||||
- local: detoxifying_a_lm
|
||||
title: Detoxifying a Language Model
|
||||
- local: multi_adapter_rl
|
||||
title: Multi Adapter RLHF
|
||||
title: Examples
|
||||
@ -103,8 +101,6 @@
|
||||
title: Model Classes
|
||||
- local: model_utils
|
||||
title: Model Utilities
|
||||
- local: best_of_n
|
||||
title: Best of N Sampling
|
||||
- local: judges
|
||||
title: Judges
|
||||
- local: callbacks
|
||||
|
@ -1,6 +1,6 @@
|
||||
# BCO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=bco,trl)
|
||||
[](https://huggingface.co/models?other=bco,trl)
|
||||
|
||||
TRL supports the Binary Classifier Optimization (BCO).
|
||||
The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0.
|
||||
@ -12,17 +12,16 @@ The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unp
|
||||
The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
|
||||
|
||||
## Expected model format
|
||||
|
||||
The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `BCOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
|
||||
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
|
||||
|
||||
```py
|
||||
```python
|
||||
training_args = BCOConfig(
|
||||
beta=0.1,
|
||||
)
|
||||
@ -35,9 +34,10 @@ bco_trainer = BCOTrainer(
|
||||
processing_class=tokenizer,
|
||||
)
|
||||
```
|
||||
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
```python
|
||||
bco_trainer.train()
|
||||
```
|
||||
|
||||
@ -49,7 +49,7 @@ If the prompts in your desired and undesired datasets differ a lot, it is useful
|
||||
|
||||
Choose an embedding model and tokenizer:
|
||||
|
||||
```py
|
||||
```python
|
||||
embedding_model = AutoModel.from_pretrained(your_model_id)
|
||||
embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id)
|
||||
|
||||
@ -64,7 +64,7 @@ embedding_func = partial(embed_prompt, model=embedding_model)
|
||||
|
||||
Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function:
|
||||
|
||||
```py
|
||||
```python
|
||||
training_args = BCOConfig(
|
||||
beta=0.1,
|
||||
prompt_sample_size=512,
|
||||
|
@ -1,74 +0,0 @@
|
||||
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
|
||||
|
||||
Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
|
||||
As to how it fares against the RL based fine-tuning, please look in the `examples` directory for a comparison example
|
||||
|
||||
## Usage
|
||||
|
||||
To get started quickly, instantiate an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries
|
||||
|
||||
```python
|
||||
|
||||
from transformers import pipeline, AutoTokenizer
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from trl.core import LengthSampler
|
||||
from trl.extras import BestOfNSampler
|
||||
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)
|
||||
reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
# callable that takes a list of raw text and returns a list of corresponding reward scores
|
||||
def queries_to_scores(list_of_strings):
|
||||
return [output["score"] for output in reward_pipe(list_of_strings)]
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler)
|
||||
|
||||
|
||||
```
|
||||
|
||||
And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method
|
||||
|
||||
```python
|
||||
|
||||
best_of_n.generate(query_tensors, device=device, **gen_kwargs)
|
||||
|
||||
```
|
||||
The default sample size is 4, but you can change it at the time of instance initialization like so
|
||||
|
||||
```python
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8)
|
||||
|
||||
```
|
||||
|
||||
The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization
|
||||
|
||||
```python
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2)
|
||||
|
||||
```
|
||||
|
||||
There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method.
|
||||
This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization
|
||||
|
||||
```python
|
||||
|
||||
from transformers import GenerationConfig
|
||||
|
||||
generation_config = GenerationConfig(min_length= -1, top_k=0.0, top_p= 1.0, do_sample= True, pad_token_id=tokenizer.eos_token_id)
|
||||
|
||||
best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config)
|
||||
|
||||
best_of_n.generate(query_tensors, device=device)
|
||||
|
||||
```
|
||||
|
||||
Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query
|
||||
|
||||
## BestOfNSampler
|
||||
|
||||
[[autodoc]] BestOfNSampler
|
@ -2,9 +2,11 @@
|
||||
|
||||
TRL provides a powerful command-line interface (CLI) to fine-tune large language models (LLMs) using methods like Supervised Fine-Tuning (SFT), Direct Preference Optimization (DPO), and more. The CLI abstracts away much of the boilerplate, letting you launch training jobs quickly and reproducibly.
|
||||
|
||||
## Commands
|
||||
|
||||
Currently supported commands are:
|
||||
|
||||
#### Training Commands
|
||||
### Training Commands
|
||||
|
||||
- `trl dpo`: fine-tune a LLM with DPO
|
||||
- `trl grpo`: fine-tune a LLM with GRPO
|
||||
@ -13,7 +15,7 @@ Currently supported commands are:
|
||||
- `trl rloo`: fine-tune a LLM with RLOO
|
||||
- `trl sft`: fine-tune a LLM with SFT
|
||||
|
||||
#### Other Commands
|
||||
### Other Commands
|
||||
|
||||
- `trl env`: get the system information
|
||||
- `trl vllm-serve`: serve a model with vLLM
|
||||
@ -197,22 +199,22 @@ trl reward --config reward_config.yaml
|
||||
|
||||
The `--accelerate_config` flag lets you easily configure distributed training with [🤗 Accelerate](https://github.com/huggingface/accelerate). This flag accepts either:
|
||||
|
||||
* the name of a predefined config profile (built into TRL), or
|
||||
* a path to a custom Accelerate YAML config file.
|
||||
- the name of a predefined config profile (built into TRL), or
|
||||
- a path to a custom Accelerate YAML config file.
|
||||
|
||||
#### Predefined Config Profiles
|
||||
|
||||
TRL provides several ready-to-use Accelerate configs to simplify common training setups:
|
||||
|
||||
| Name | Description |
|
||||
| ------------ | ----------------------------------- |
|
||||
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
|
||||
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
|
||||
| `zero1` | DeepSpeed ZeRO Stage 1 |
|
||||
| `zero2` | DeepSpeed ZeRO Stage 2 |
|
||||
| `zero3` | DeepSpeed ZeRO Stage 3 |
|
||||
| `multi_gpu` | Multi-GPU training |
|
||||
| `single_gpu` | Single-GPU training |
|
||||
| Name | Description |
|
||||
| --- | --- |
|
||||
| `fsdp1` | Fully Sharded Data Parallel Stage 1 |
|
||||
| `fsdp2` | Fully Sharded Data Parallel Stage 2 |
|
||||
| `zero1` | DeepSpeed ZeRO Stage 1 |
|
||||
| `zero2` | DeepSpeed ZeRO Stage 2 |
|
||||
| `zero3` | DeepSpeed ZeRO Stage 3 |
|
||||
| `multi_gpu` | Multi-GPU training |
|
||||
| `single_gpu` | Single-GPU training |
|
||||
|
||||
To use one of these, just pass the name to `--accelerate_config`. TRL will automatically load the corresponding config file from `trl/accelerate_config/`.
|
||||
|
||||
|
@ -8,6 +8,7 @@ Community tutorials are made by active members of the Hugging Face community who
|
||||
|
||||
| Task | Class | Description | Author | Tutorial | Colab |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Efficient Online Training with GRPO and vLLM in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/grpo_vllm_online_training) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/grpo_vllm_online_training.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Post training an LLM for reasoning with GRPO in TRL | [Sergio Paniego](https://huggingface.co/sergiopaniego) | [Link](https://huggingface.co/learn/cookbook/fine_tuning_llm_grpo_trl) | [](https://colab.research.google.com/github/huggingface/cookbook/blob/main/notebooks/en/fine_tuning_llm_grpo_trl.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/mini-deepseek-r1) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb) |
|
||||
| Reinforcement Learning | [`GRPOTrainer`] | RL on LLaMA 3.1-8B with GRPO and Unsloth optimizations | [Andrea Manzoni](https://huggingface.co/AManzoni) | [Link](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) | [](https://colab.research.google.com/github/amanzoni1/fine_tuning/blob/main/RL_LLama3_1_8B_GRPO.ipynb) |
|
||||
@ -17,7 +18,6 @@ Community tutorials are made by active members of the Hugging Face community who
|
||||
| Preference Optimization | [`ORPOTrainer`] | Fine-tuning Llama 3 with ORPO combining instruction tuning and preference alignment | [Maxime Labonne](https://huggingface.co/mlabonne) | [Link](https://mlabonne.github.io/blog/posts/2024-04-19_Fine_tune_Llama_3_with_ORPO.html) | [](https://colab.research.google.com/drive/1eHNWg9gnaXErdAa8_mcvjMupbSS6rDvi) |
|
||||
| Instruction tuning | [`SFTTrainer`] | How to fine-tune open LLMs in 2025 with Hugging Face | [Philipp Schmid](https://huggingface.co/philschmid) | [Link](https://www.philschmid.de/fine-tune-llms-in-2025) | [](https://colab.research.google.com/github/philschmid/deep-learning-pytorch-huggingface/blob/main/training/fine-tune-llms-in-2025.ipynb) |
|
||||
|
||||
|
||||
### Videos
|
||||
|
||||
| Task | Title | Author | Video |
|
||||
@ -31,6 +31,7 @@ Community tutorials are made by active members of the Hugging Face community who
|
||||
|
||||
> [!WARNING]
|
||||
> The tutorial uses two deprecated features:
|
||||
>
|
||||
> - `SFTTrainer(..., tokenizer=tokenizer)`: Use `SFTTrainer(..., processing_class=tokenizer)` instead, or simply omit it (it will be inferred from the model).
|
||||
> - `setup_chat_format(model, tokenizer)`: Use `SFTConfig(..., chat_template_path="Qwen/Qwen3-0.6B")`, where `chat_template_path` specifies the model whose chat template you want to copy.
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# CPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=cpo,trl)
|
||||
[](https://huggingface.co/models?other=cpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -98,15 +98,13 @@ To use this loss as described in the paper, we can set the `loss_type="alphapo"`
|
||||
|
||||
The CPO algorithm supports several loss functions. The loss function can be set using the `loss_type` parameter in the [`CPOConfig`]. The following loss functions are supported:
|
||||
|
||||
| `loss_type=` | Description |
|
||||
| -------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact, the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
|
||||
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
|
||||
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO, which is summed only). |
|
||||
| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and `simpo_gamma` to a recommended value. |
|
||||
| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. |
|
||||
|
||||
|
||||
| `loss_type=` | Description |
|
||||
| --- | --- |
|
||||
| `"sigmoid"` (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model, and in fact, the [DPO](https://huggingface.co/papers/2305.18290) authors propose the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression. |
|
||||
| `"hinge"` | The [RSO](https://huggingface.co/papers/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. In this case, the `beta` is the reciprocal of the margin. |
|
||||
| `"ipo"` | The [IPO](https://huggingface.co/papers/2310.12036) authors provide a deeper theoretical understanding of the DPO algorithms and identify an issue with overfitting and propose an alternative loss. In this case, the `beta` is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair, and thus the smaller the `beta`, the larger this gap is. As per the paper, the loss is averaged over log-likelihoods of the completion (unlike DPO, which is summed only). |
|
||||
| `"simpo"` | The [SimPO](https://huggingface.co/papers/2405.14734) method is also implemented in the [`CPOTrainer`]. SimPO is an alternative loss that adds a reward margin, allows for length normalization, and does not use BC regularization. To use this loss, simply set `loss_type="simpo"` and `cpo_alpha=0.0` in the [`CPOConfig`] and `simpo_gamma` to a recommended value. |
|
||||
| `"alphapo"` | The [AlphaPO](https://huggingface.co/papers/2501.03884) method is also implemented in the [`CPOTrainer`]. This is syntactic sugar that automatically sets `loss_type="simpo"` and `cpo_alpha=0.0`. AlphaPO applies a transformation to the reward function shape in the context of SimPO loss when the `alpha` parameter is non-zero. |
|
||||
|
||||
### For Mixture of Experts Models: Enabling the auxiliary loss
|
||||
|
||||
|
@ -2,8 +2,6 @@
|
||||
|
||||
TRL is designed with modularity in mind so that users are able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
|
||||
|
||||
|
||||
|
||||
## Use different optimizers and schedulers
|
||||
|
||||
By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows:
|
||||
@ -84,11 +82,11 @@ trainer = DPOTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Pass 8-bit reference models
|
||||
|
||||
## Pass 8-bit reference models
|
||||
|
||||
Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
|
||||
|
||||
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
|
||||
Read more about 8-bit model loading in `transformers` [Load in 8bit or 4bit](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
|
@ -81,7 +81,7 @@ This guide provides an overview of the dataset formats and types supported by ea
|
||||
<td>Stepwise supervision</td>
|
||||
<td>
|
||||
<pre><code>{"prompt": "Which number is larger, 9.8 or 9.11?",
|
||||
"completions": ["The fractional part of 9.8 is 0.8.",
|
||||
"completions": ["The fractional part of 9.8 is 0.8.",
|
||||
"The fractional part of 9.11 is 0.11.",
|
||||
"0.11 is greater than 0.8.",
|
||||
"Hence, 9.11 > 9.8."],
|
||||
@ -387,23 +387,23 @@ For examples of stepwise supervision datasets, refer to the [Stepwise supervisio
|
||||
|
||||
Choosing the right dataset type depends on the task you are working on and the specific requirements of the TRL trainer you are using. Below is a brief overview of the dataset types supported by each TRL trainer.
|
||||
|
||||
| Trainer | Expected dataset type |
|
||||
| ----------------------- | ------------------------------------------------------------------------------------------------------ |
|
||||
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
|
||||
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`PPOTrainer`] | Tokenized language modeling |
|
||||
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
|
||||
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
|
||||
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
|
||||
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| Trainer | Expected dataset type |
|
||||
| --- | --- |
|
||||
| [`BCOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`CPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`DPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`GKDTrainer`] | [Prompt-completion](#prompt-completion) |
|
||||
| [`GRPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`KTOTrainer`] | [Unpaired preference](#unpaired-preference) or [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`PPOTrainer`] | Tokenized language modeling |
|
||||
| [`PRMTrainer`] | [Stepwise supervision](#stepwise-supervision) |
|
||||
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
|
||||
| [`RLOOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`SFTTrainer`] | [Language modeling](#language-modeling) or [Prompt-completion](#prompt-completion) |
|
||||
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
|
||||
> [!TIP]
|
||||
> TRL trainers only support standard dataset formats, [for now](https://github.com/huggingface/trl/issues/2071). If you have a conversational dataset, you must first convert it into a standard format.
|
||||
@ -416,7 +416,7 @@ Fortunately, TRL offers tools to easily handle this conversion, which are detail
|
||||
|
||||
### Converting a conversational dataset into a standard dataset
|
||||
|
||||
To convert a conversational dataset into a standard dataset, you need to _apply a chat template_ to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
|
||||
To convert a conversational dataset into a standard dataset, you need to *apply a chat template* to the dataset. A chat template is a predefined structure that typically includes placeholders for user and assistant messages. This template is provided by the tokenizer of the model you use.
|
||||
|
||||
For detailed instructions on using chat templating, refer to the [Chat templating section in the `transformers` documentation](https://huggingface.co/docs/transformers/en/chat_templating).
|
||||
|
||||
@ -519,15 +519,15 @@ This section provides example code to help you convert between different dataset
|
||||
|
||||
For simplicity, some of the examples below do not follow this recommendation and use the standard format. However, the conversions can be applied directly to the conversational format without modification.
|
||||
|
||||
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
|
||||
| ------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------------- | ----------------------------------------------------------------- | --------------------------------------------------------- | --------------------------------------------------------- | ------------------------------------------------------------------------- | -------------------- |
|
||||
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
|
||||
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
|
||||
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
|
||||
| From \ To | Language modeling | Prompt-completion | Prompt-only | Preference with implicit prompt | Preference | Unpaired preference | Stepwise supervision |
|
||||
| --- | --- | --- | --- | --- | --- | --- | --- |
|
||||
| Language modeling | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Prompt-completion | [🔗](#from-prompt-completion-to-language-modeling-dataset) | N/A | [🔗](#from-prompt-completion-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Prompt-only | N/A | N/A | N/A | N/A | N/A | N/A | N/A |
|
||||
| Preference with implicit prompt | [🔗](#from-preference-with-implicit-prompt-to-language-modeling-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-completion-dataset) | [🔗](#from-preference-with-implicit-prompt-to-prompt-only-dataset) | N/A | [🔗](#from-implicit-to-explicit-prompt-preference-dataset) | [🔗](#from-preference-with-implicit-prompt-to-unpaired-preference-dataset) | N/A |
|
||||
| Preference | [🔗](#from-preference-to-language-modeling-dataset) | [🔗](#from-preference-to-prompt-completion-dataset) | [🔗](#from-preference-to-prompt-only-dataset) | [🔗](#from-explicit-to-implicit-prompt-preference-dataset) | N/A | [🔗](#from-preference-to-unpaired-preference-dataset) | N/A |
|
||||
| Unpaired preference | [🔗](#from-unpaired-preference-to-language-modeling-dataset) | [🔗](#from-unpaired-preference-to-prompt-completion-dataset) | [🔗](#from-unpaired-preference-to-prompt-only-dataset) | N/A | N/A | N/A | N/A |
|
||||
| Stepwise supervision | [🔗](#from-stepwise-supervision-to-language-modeling-dataset) | [🔗](#from-stepwise-supervision-to-prompt-completion-dataset) | [🔗](#from-stepwise-supervision-to-prompt-only-dataset) | N/A | N/A | [🔗](#from-stepwise-supervision-to-unpaired-preference-dataset) | N/A |
|
||||
|
||||
### From prompt-completion to language modeling dataset
|
||||
|
||||
|
@ -1,187 +0,0 @@
|
||||
# Detoxifying a Language Model using PPO
|
||||
|
||||
Language models (LMs) are known to sometimes generate toxic outputs. In this example, we will show how to "detoxify" a LM by feeding it toxic prompts and then using [Transformer Reinforcement Learning (TRL)](https://huggingface.co/docs/trl/index) and Proximal Policy Optimization (PPO) to "detoxify" it.
|
||||
|
||||
Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
|
||||
|
||||
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
|
||||
|
||||
| File | Description | Colab link |
|
||||
|---|---| --- |
|
||||
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
|
||||
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
|
||||
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
|
||||
|
||||
## Context
|
||||
|
||||
Language models are trained on large volumes of text from the internet which also includes a lot of toxic content. Naturally, language models pick up the toxic patterns during training. Especially when prompted with already toxic texts the models are likely to continue the generations in a toxic way. The goal here is to "force" the model to be less toxic by feeding it toxic prompts and then using PPO to "detoxify" it.
|
||||
|
||||
### Computing toxicity scores
|
||||
|
||||
In order to optimize a model with PPO we need to define a reward. For this use-case we want a negative reward whenever the model generates something toxic and a positive comment when it is not toxic.
|
||||
Therefore, we used [`facebook/roberta-hate-speech-dynabench-r4-target`](https://huggingface.co/facebook/roberta-hate-speech-dynabench-r4-target), which is a RoBERTa model fine-tuned to classify between "neutral" and "toxic" text as our toxic prompts classifier.
|
||||
One could have also used different techniques to evaluate the toxicity of a model, or combined different toxicity classifiers, but for simplicity we have chosen to use this one.
|
||||
|
||||
### Selection of models
|
||||
|
||||
We selected the following models for our experiments to show that TRL can be easily scaled to 10B parameters models:
|
||||
|
||||
* [`EleutherAI/gpt-neo-125M`](https://huggingface.co/EleutherAI/gpt-neo-125M) (125 million parameters)
|
||||
* [`EleutherAI/gpt-neo-2.7B`](https://huggingface.co/EleutherAI/gpt-neo-2.7B) (2.7 billion parameters)
|
||||
* [`EleutherAI/gpt-j-6B`](https://huggingface.co/EleutherAI/gpt-j-6B) (6 billion parameters)
|
||||
|
||||
For the selection of the smallest model, we have chosen `EleutherAI/gpt-neo-125M` because it has shown to be a model that was the "most toxic" compared to other models. We have run toxicity evaluation using `facebook/roberta-hate-speech-dynabench-r4-target` model on 4 different architectures on a subset of `allenai/real-toxicity-prompts` dataset. Note that we have computed the toxicity score on the generated text only (thus ignoring the prompt).
|
||||
|
||||
| Model | Mean toxicity score |
|
||||
|---|---|
|
||||
| `gpt2` | 0.01602 |
|
||||
| `facebook/opt-350m` | 0.01628 |
|
||||
| `bigscience/bloom-560m` | 0.00767 |
|
||||
| `EleutherAI/gpt-neo-125M` | **0.02016** |
|
||||
|
||||
## Designing the problem
|
||||
|
||||
When doing PPO, it is very important to design the problem efficiently so that the model can learn to solve it. Let's cover the topics that were important for the model to converge.
|
||||
|
||||
### Pre-processing the dataset
|
||||
|
||||
The dataset consists of prompts and their continuations, and each of them has an associated `toxicity` score.
|
||||
|
||||
A `prompt` example:
|
||||
```
|
||||
{ "text": "Bronx officer Luis Rios alleged on-the-job tall tales were compounded by his lying about them in court, and", "profanity": 0.08360514, "sexually_explicit": 0.118823394, "identity_attack": 0.09608547, "flirtation": 0.23102224, "threat": 0.13215046, "insult": 0.10130461, "severe_toxicity": 0.04068885, "toxicity": 0.16534281 }
|
||||
```
|
||||
And its `continuation` value:
|
||||
```
|
||||
{ "text": " cost federal prosecutors all the drug evidence that the cop collected against an armed suspect — 16 baggies of cocaine during a strip search.", "severe_toxicity": 0.067997746, "toxicity": 0.1694093, "profanity": 0.11931301, "sexually_explicit": 0.12521537, "identity_attack": 0.09268324, "flirtation": 0.13452998, "threat": 0.31312028, "insult": 0.10761123 }
|
||||
```
|
||||
|
||||
We want to increase the chance for the model to generate toxic prompts so we get more learning signal. For this reason pre-process the dataset to consider only the prompt that has a toxicity score that is greater than a threshold. We can do this in a few lines of code:
|
||||
```python
|
||||
train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train")
|
||||
|
||||
def filter_fn(sample):
|
||||
toxicity = sample["prompt"]["toxicity"]
|
||||
return toxicity is not None and toxicity > 0.3
|
||||
|
||||
train_dataset = train_dataset.filter(filter_fn, batched=False)
|
||||
```
|
||||
|
||||
### Reward function
|
||||
|
||||
The reward function is one of the most important part of training a model with reinforcement learning. It is the function that will tell the model if it is doing well or not.
|
||||
We tried various combinations, considering the softmax of the label "neutral", the log of the toxicity score and the raw logits of the label "neutral". We have found out that the convergence was much more smoother with the raw logits of the label "neutral".
|
||||
```python
|
||||
logits = toxicity_model(**toxicity_inputs).logits.float()
|
||||
rewards = (logits[:, 0]).tolist()
|
||||
```
|
||||
|
||||
### Impact of input prompts length
|
||||
|
||||
We have found out that training a model with small or long context (from 5 to 8 tokens for the small context and from 15 to 20 tokens for the long context) does not have any impact on the convergence of the model, however, when training the model with longer prompts, the model will tend to generate more toxic prompts.
|
||||
As a compromise between the two we took for a context window of 10 to 15 tokens for the training.
|
||||
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-long-vs-short-context.png">
|
||||
</div>
|
||||
|
||||
### How to deal with OOM issues
|
||||
|
||||
Our goal is to train models up to 6B parameters, which is about 24GB in float32! Here are two tricks we use to be able to train a 6B model on a single 40GB-RAM GPU:
|
||||
|
||||
- Use `bfloat16` precision: Simply load your model in `bfloat16` when calling `from_pretrained` and you can reduce the size of the model by 2:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `dtype` and specify the mixed precision argument when calling `accelerate config`.
|
||||
|
||||
- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-shared-layers.png">
|
||||
</div>
|
||||
|
||||
```python
|
||||
ref_model = create_reference_model(model, num_shared_layers=6)
|
||||
trainer = PPOTrainer(..., ref_model=ref_model)
|
||||
```
|
||||
|
||||
In the example above this means that the model has the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model).
|
||||
|
||||
- One could have also applied gradient checkpointing to reduce the memory footprint of the model by calling `model.pretrained_model.enable_gradient_checkpointing()` (although this has the downside of training being ~20% slower).
|
||||
|
||||
## Training the model!
|
||||
|
||||
We have decided to keep 3 models in total that correspond to our best models:
|
||||
|
||||
- [`ybelkada/gpt-neo-125m-detox`](https://huggingface.co/ybelkada/gpt-neo-125m-detox)
|
||||
- [`ybelkada/gpt-neo-2.7B-detox`](https://huggingface.co/ybelkada/gpt-neo-2.7B-detox)
|
||||
- [`ybelkada/gpt-j-6b-detox`](https://huggingface.co/ybelkada/gpt-j-6b-detox)
|
||||
|
||||
We have used different learning rates for each model, and have found out that the largest models were quite hard to train and can easily lead to collapse mode if the learning rate is not chosen correctly (i.e. if the learning rate is too high):
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-collapse-mode.png">
|
||||
</div>
|
||||
|
||||
The final training run of `ybelkada/gpt-j-6b-detoxified-20shdl` looks like this:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-final-run-2.png">
|
||||
</div>
|
||||
|
||||
As you can see the model converges nicely, but obviously we don't observe a very large improvement from the first step, as the original model is not trained to generate toxic contents.
|
||||
|
||||
Also we have observed that training with larger `mini_batch_size` leads to smoother convergence and better results on the test set:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-gpt-j-mbs-run.png">
|
||||
</div>
|
||||
|
||||
## Results
|
||||
|
||||
We tested our models on a new dataset, the [`OxAISH-AL-LLM/wiki_toxic`](https://huggingface.co/datasets/OxAISH-AL-LLM/wiki_toxic) dataset. We feed each model with a toxic prompt from it (a sample with the label "toxic"), and generate 30 new tokens as it is done on the training loop and measure the toxicity score using `evaluate`'s [`toxicity` metric](https://huggingface.co/spaces/ybelkada/toxicity).
|
||||
We report the toxicity score of 400 sampled examples, compute its mean and standard deviation and report the results in the table below:
|
||||
|
||||
| Model | Mean toxicity score | Std toxicity score |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-neo-125m` | 0.1627 | 0.2997 |
|
||||
| `ybelkada/gpt-neo-125m-detox` | **0.1148** | **0.2506** |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-neo-2.7B` | 0.1884 | 0.3178 |
|
||||
| `ybelkada/gpt-neo-2.7B-detox` | **0.0916** | **0.2104** |
|
||||
| --- | --- | --- |
|
||||
| `EleutherAI/gpt-j-6B` | 0.1699 | 0.3033 |
|
||||
| `ybelkada/gpt-j-6b-detox` | **0.1510** | **0.2798** |
|
||||
|
||||
<div class="column" style="text-align:center">
|
||||
<figure>
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-final-barplot.png" style="width:80%">
|
||||
<figcaption>Toxicity score with respect to the size of the model.</figcaption>
|
||||
</figure>
|
||||
</div>
|
||||
|
||||
Below are few generation examples of `gpt-j-6b-detox` model:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-toxicity-examples.png">
|
||||
</div>
|
||||
|
||||
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
|
||||
|
||||
### Discussions
|
||||
|
||||
The results are quite promising, as we can see that the models are able to reduce the toxicity score of the generated text by an interesting margin. The gap is clear for `gpt-neo-2B` model but we see less so for the `gpt-j-6B` model. There are several things we could try to improve the results on the largest model starting with training with larger `mini_batch_size` and probably allowing to back-propagate through more layers (i.e. use less shared layers).
|
||||
|
||||
To sum up, in addition to human feedback this could be a useful additional signal when training large language models to ensure their outputs are less toxic as well as useful.
|
||||
|
||||
### Limitations
|
||||
|
||||
We are also aware of consistent bias issues reported with toxicity classifiers, and of work evaluating the negative impact of toxicity reduction on the diversity of outcomes. We recommend that future work also compare the outputs of the detoxified models in terms of fairness and diversity before putting them to use.
|
||||
|
||||
## What is next?
|
||||
|
||||
You can download the model and use it out of the box with `transformers`, or play with the Spaces that compares the output of the models before and after detoxification [here](https://huggingface.co/spaces/ybelkada/detoxified-lms).
|
@ -26,11 +26,12 @@ accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train
|
||||
This automatically distributes the workload across all available GPUs.
|
||||
|
||||
Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process:
|
||||
|
||||
- Processes its own batch of data
|
||||
- Computes the loss and gradients for that batch
|
||||
- Shares gradient updates across all GPUs
|
||||
|
||||

|
||||

|
||||
|
||||
The effective batch size is calculated as:
|
||||
|
||||
@ -177,8 +178,7 @@ These results show that **Context Parallelism (CP) scales effectively with more
|
||||
>
|
||||
> You can learn more and explore configuration examples in the [Accelerate ND-parallelism guide](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#nd-parallelism).
|
||||
|
||||
|
||||
**Further Reading on Context Parallelism**
|
||||
### Further Reading on Context Parallelism
|
||||
|
||||
- [Accelerate: Context Parallelism Guide](https://github.com/huggingface/accelerate/blob/main/docs/source/concept_guides/context_parallelism.md)
|
||||
- [Accelerate Example: 128k Sequence Length](https://github.com/huggingface/accelerate/blob/main/examples/torch_native_parallelism/README.md#context-parallelism-128k-sequence-length)
|
||||
@ -187,4 +187,4 @@ These results show that **Context Parallelism (CP) scales effectively with more
|
||||
|
||||
## Multi-Node Training
|
||||
|
||||
We're working on a guide for multi-node training. Stay tuned! 🚀
|
||||
We're working on a guide for multi-node training. Stay tuned! 🚀
|
||||
|
@ -1,6 +1,6 @@
|
||||
# DPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=dpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
[](https://huggingface.co/models?other=dpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -19,7 +19,7 @@ Then, fine-tuning a language model via DPO consists of two steps and is easier t
|
||||
|
||||
This process is illustrated in the sketch below (from [Figure 1 of the DPO paper](https://huggingface.co/papers/2305.18290)):
|
||||
|
||||

|
||||

|
||||
|
||||
Read more about DPO algorithm in the [original paper](https://huggingface.co/papers/2305.18290).
|
||||
|
||||
@ -101,7 +101,6 @@ Additionally, unlike standard text-based models where a `tokenizer` is used, for
|
||||
|
||||
For a complete example of fine-tuning a vision-language model, refer to the script in [`examples/scripts/dpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_vlm.py).
|
||||
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the DPO method. The script is available in [`trl/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py)
|
||||
@ -192,10 +191,10 @@ To scale how much the auxiliary loss contributes to the total loss, use the hype
|
||||
|
||||
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the [`unsloth`](https://github.com/unslothai/unsloth) library that is fully compatible with `SFTTrainer`. Currently `unsloth` supports only Llama (Yi, TinyLlama, Qwen, Deepseek etc) and Mistral architectures. Some benchmarks for DPO listed below:
|
||||
|
||||
| GPU | Model | Dataset | 🤗 | 🤗 + FlashAttention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| -------- | --------- | ---------- | --- | --------------------- | --------- | ------------ |
|
||||
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
|
||||
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
|
||||
| GPU | Model | Dataset | 🤗 | 🤗 + FlashAttention 2 | 🦥 Unsloth | 🦥 VRAM saved |
|
||||
| --- | --- | --- | --- | --- | --- | --- |
|
||||
| A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | **1.88x** | -11.6% |
|
||||
| Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | **1.55x** | -18.6% |
|
||||
|
||||
First install `unsloth` according to the [official documentation](https://github.com/unslothai/unsloth). Once installed, you can incorporate unsloth into your workflow in a very simple manner; instead of loading `AutoModelForCausalLM`, you just need to load a `FastLanguageModel` as follows:
|
||||
|
||||
|
@ -1,16 +1,15 @@
|
||||
# Examples
|
||||
|
||||
|
||||
## Introduction
|
||||
|
||||
The examples should work in any of the following settings (with the same script):
|
||||
- single GPU
|
||||
- multi GPUs (using PyTorch distributed mode)
|
||||
- multi GPUs (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
|
||||
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
|
||||
|
||||
To run it in each of these various modes, first initialize the accelerate
|
||||
configuration with `accelerate config`
|
||||
- single GPU
|
||||
- multi GPUs (using PyTorch distributed mode)
|
||||
- multi GPUs (using DeepSpeed ZeRO-Offload stages 1, 2, & 3)
|
||||
- fp16 (mixed-precision), fp32 (normal precision), or bf16 (bfloat16 precision)
|
||||
|
||||
To run it in each of these various modes, first initialize the accelerate configuration with `accelerate config`.
|
||||
|
||||
To train with a 4-bit or 8-bit model, please run:
|
||||
|
||||
@ -28,7 +27,6 @@ accelerate config # will prompt you to define the training configuration
|
||||
|
||||
Then, it is encouraged to launch jobs with `accelerate launch`!
|
||||
|
||||
|
||||
## Maintained Examples
|
||||
|
||||
Scripts can be used as examples of how to use TRL trainers. They are located in the [`trl/scripts`](https://github.com/huggingface/trl/blob/main/trl/scripts) directory. Additionally, we provide examples in the [`examples/scripts`](https://github.com/huggingface/trl/blob/main/examples/scripts) directory. These examples are maintained and tested regularly.
|
||||
@ -42,9 +40,9 @@ Scripts can be used as examples of how to use TRL trainers. They are located in
|
||||
| [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`HfPairwiseJudge`] or [`OpenAIPairwiseJudge`] to judge model generations. |
|
||||
| [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`GKDTrainer`] to fine-tune a model. |
|
||||
| [`trl/scripts/grpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
|
||||
| [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. |
|
||||
| [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
|
||||
| [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
|
||||
| [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. |
|
||||
| [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. |
|
||||
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/mpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/mpo_vlm.py) | This script shows how to use MPO via the [`DPOTrainer`] to align a model based on preferences using the [HuggingFaceH4/rlaif-v_formatted](https://huggingface.co/datasets/HuggingFaceH4/rlaif-v_formatted) dataset and a set of loss weights with weights. |
|
||||
| [`examples/scripts/nash_md.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/nash_md.py) | This script shows how to use the [`NashMDTrainer`] to fine-tune a model. |
|
||||
@ -68,15 +66,9 @@ Here are also some easier-to-run colab notebooks that you can use to get started
|
||||
|
||||
| File | Description |
|
||||
| --- | --- |
|
||||
| [`examples/notebooks/best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb) | This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
|
||||
|
||||
|
||||
We also have some other examples that are less maintained but can be used as a reference:
|
||||
1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
|
||||
|
||||
|
||||
## Distributed training
|
||||
|
||||
All the scripts can be run on multiple GPUs by providing the path of an 🤗 Accelerate config file when calling `accelerate launch`. To launch one of them on one or multiple GPUs, run the following command (swapping `{NUM_GPUS}` with the number of GPUs in your machine and `--all_arguments_of_the_script` with your arguments).
|
||||
|
@ -1,17 +1,17 @@
|
||||
# Generalized Knowledge Distillation Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=gkd,trl)
|
||||
[](https://huggingface.co/models?other=gkd,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
|
||||
Generalized Knowledge Distillation (GKD) was proposed in [On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes](https://huggingface.co/papers/2306.13649) by Rishabh Agarwal, Nino Vieillard, Yongchao Zhou, Piotr Stanczyk, Sabela Ramos, Matthieu Geist, and Olivier Bachem.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Knowledge distillation (KD) is widely used for compressing a teacher model to reduce its inference cost and memory footprint, by training a smaller student model. However, current KD methods for auto-regressive sequence models suffer from distribution mismatch between output sequences seen during training and those generated by the student during inference. To address this issue, we introduce Generalized Knowledge Distillation (GKD). Instead of solely relying on a fixed set of output sequences, GKD trains the student on its self-generated output sequences by leveraging feedback from the teacher on such sequences. Unlike supervised KD approaches, GKD also offers the flexibility to employ alternative loss functions between the student and teacher, which can be useful when the student lacks the expressivity to mimic the teacher's distribution. Furthermore, GKD facilitates the seamless integration of distillation with RL fine-tuning (RLHF). We demonstrate the efficacy of GKD for distilling auto-regressive language models on summarization, translation, and arithmetic reasoning tasks, and task-agnostic distillation for instruction-tuning.
|
||||
|
||||
|
||||
The key aspects of GKD are:
|
||||
|
||||
1. It addresses the train-inference distribution mismatch in auto-regressive sequence models by training the student model on its self-generated output sequences.
|
||||
2. GKD allows flexibility in choosing different divergence measures between student and teacher models via the generalized Jensen-Shannon Divergence (JSD), which can be useful when the student lacks the capacity to fully mimic the teacher.
|
||||
|
||||
@ -20,6 +20,7 @@ This post-training method was contributed by [Kashif Rasul](https://huggingface.
|
||||
## Usage tips
|
||||
|
||||
The [`GKDTrainer`] is a wrapper around the [`SFTTrainer`] class that takes in a teacher model argument. It needs three parameters to be set via the [`GKDConfig`] namely:
|
||||
|
||||
* `lmbda`: controls the student data fraction, i.e., the proportion of on-policy student-generated outputs. When `lmbda=0.0`, the loss reduces to supervised JSD where the student is trained with the token-level probabilities of the teacher. When `lmbda=1.0`, the loss reduces to on-policy JSD, where the student generates output sequences and token-specific feedback on these sequences from the teacher. For values in between [0, 1] it is random between the two based on the `lmbda` value for each batch.
|
||||
* `seq_kd`: controls whether to perform Sequence-Level KD (can be viewed as supervised FT on teacher-generated out). When `seq_kd=True` and `lmbda=0.0`, the loss reduces to supervised JSD, where the teacher generates output sequences and the student receives token-specific feedback on these sequences from the teacher.
|
||||
* `beta`: controls the interpolation in the generalized Jensen-Shannon Divergence. When `beta=0.0` the loss approximates forward KL divergence, while for `beta=1.0` the loss approximates reverse KL divergence. For values in between [0, 1] it interpolates between the two.
|
||||
@ -85,6 +86,7 @@ trainer.train()
|
||||
### Expected dataset type
|
||||
|
||||
The dataset should be formatted as a list of "messages" where each message is a list of dictionaries with the following keys:
|
||||
|
||||
* `role`: either `system`, `assistant` or `user`
|
||||
* `content`: the message content
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# GRPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=grpo,trl)
|
||||
[](https://huggingface.co/models?other=grpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -56,13 +56,13 @@ accelerate launch train_grpo.py
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 1 day.
|
||||
|
||||

|
||||

|
||||
|
||||
## Looking deeper into the GRPO method
|
||||
|
||||
GRPO is an online learning algorithm, meaning it improves iteratively by using the data generated by the trained model itself during training. The intuition behind GRPO objective is to maximize the advantage of the generated completions, while ensuring that the model remains close to the reference policy. To understand how GRPO works, it can be broken down into four main steps: **Generating completions**, **computing the advantage**, **estimating the KL divergence**, and **computing the loss**.
|
||||
|
||||

|
||||

|
||||
|
||||
### Generating completions
|
||||
|
||||
@ -80,7 +80,6 @@ This approach gives the method its name: **Group Relative Policy Optimization (G
|
||||
> It was shown in the paper [Understanding R1-Zero-Like Training: A Critical Perspective](https://huggingface.co/papers/2503.20783) that scaling by \\( \text{std}(\mathbf{r}) \\) may cause a question-level difficulty bias. You can disable this scaling by setting `scale_rewards=False` in [`GRPOConfig`].
|
||||
|
||||
> [!TIP]
|
||||
>
|
||||
> [Part I: Tricks or Traps? A Deep Dive into RL for LLM Reasoning (Lite PPO)](https://huggingface.co/papers/2508.08221) showed that calculating the mean at the local (group) level and the standard deviation at the global (batch) level enables more robust reward shaping. You can use this scaling strategy by setting `scale_rewards="batch"` in [`GRPOConfig`].
|
||||
|
||||
### Estimating the KL divergence
|
||||
@ -167,10 +166,10 @@ While training and evaluating, we record the following reward metrics:
|
||||
- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, masked sequences tokens are excluded.)
|
||||
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
|
||||
- `clip_ratio/region_mean`: The ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities where the GRPO objective is clipped to stay within the trust region:
|
||||
$$
|
||||
\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
|
||||
$$
|
||||
A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change.
|
||||
$$
|
||||
\text{clip}\left( r_{i,t}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i,t}(\theta) = \frac{\pi_\theta(o_{i,t} \mid q, o_{i,< t})}{\pi_{\theta_{\text{old}}}(o_{i,t} \mid q, o_{i,< t})}\,.
|
||||
$$
|
||||
A higher value means more tokens are clipped, which constrains how much the policy $\pi_\theta$ can change.
|
||||
- `clip_ratio/low_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
|
||||
- `clip_ratio/low_min`: The minimum ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
|
||||
- `clip_ratio/high_mean`: The average ratio of token (or sequence, if `importance_sampling_level="sequence"`) probabilities that were clipped on the upper bound of the trust region: \\(r_{i,t}(\theta) > 1 + \epsilon_\mathrm{high}\\)
|
||||
@ -181,6 +180,7 @@ A higher value means more tokens are clipped, which constrains how much the poli
|
||||
### Speed up training with vLLM-powered generation
|
||||
|
||||
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with
|
||||
|
||||
```shell
|
||||
pip install trl[vllm]
|
||||
```
|
||||
@ -195,11 +195,13 @@ We support two ways of using vLLM during training: **server mode** and **colocat
|
||||
In this mode, vLLM runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.
|
||||
|
||||
1. **Start the vLLM server**:
|
||||
|
||||
```bash
|
||||
trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
2. **Enable server mode in your training script**:
|
||||
|
||||
```python
|
||||
from trl import GRPOConfig
|
||||
|
||||
@ -232,12 +234,7 @@ training_args = GRPOConfig(
|
||||
>
|
||||
> We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get `vllm_gpu_memory_utilization` recommendation:
|
||||
>
|
||||
> <iframe
|
||||
> src="https://trl-lib-recommend-vllm-memory.hf.space"
|
||||
> frameborder="0"
|
||||
> width="850"
|
||||
> height="450"
|
||||
> ></iframe>
|
||||
> <iframe src="https://trl-lib-recommend-vllm-memory.hf.space" frameborder="0" width="850" height="450"></iframe>
|
||||
>
|
||||
> If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
|
||||
>
|
||||
@ -436,6 +433,7 @@ You can test this function as follows:
|
||||
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
|
||||
[1.0, 0.0]
|
||||
```
|
||||
|
||||
#### Example 4: Multi-task reward functions
|
||||
|
||||
Below is an example of using multiple reward functions in the [`GRPOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.
|
||||
@ -496,8 +494,6 @@ In this example, the `math_reward_func` and `coding_reward_func` are designed to
|
||||
|
||||
Note that the [`GRPOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function.
|
||||
|
||||
|
||||
|
||||
#### Passing the reward function to the trainer
|
||||
|
||||
To use your custom reward function, pass it to the [`GRPOTrainer`] as follows:
|
||||
|
@ -9,11 +9,13 @@ The library is integrated with 🤗 [transformers](https://github.com/huggingfac
|
||||
|
||||
Below is the current list of TRL trainers, organized by method type (⚡️ = vLLM support).
|
||||
|
||||
<div style="display: flex; justify-content: space-between; width: 100%; gap: 2rem;">
|
||||
## Taxonomy
|
||||
|
||||
<div style="display: flex; justify-content: space-between; width: 100%; gap: 2rem;">
|
||||
<div style="flex: 1; min-width: 0;">
|
||||
|
||||
**Online methods**
|
||||
### Online methods
|
||||
|
||||
- [`GRPOTrainer`] ⚡️
|
||||
- [`RLOOTrainer`] ⚡️
|
||||
- [`OnlineDPOTrainer`] ⚡️
|
||||
@ -21,15 +23,16 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
|
||||
- [`XPOTrainer`] ⚡️
|
||||
- [`PPOTrainer`]
|
||||
|
||||
**Reward modeling**
|
||||
### Reward modeling
|
||||
|
||||
- [`PRMTrainer`]
|
||||
- [`RewardTrainer`]
|
||||
|
||||
</div>
|
||||
|
||||
<div style="flex: 1; min-width: 0;">
|
||||
|
||||
**Offline methods**
|
||||
### Offline methods
|
||||
|
||||
- [`SFTTrainer`]
|
||||
- [`DPOTrainer`]
|
||||
- [`ORPOTrainer`]
|
||||
@ -37,14 +40,13 @@ Below is the current list of TRL trainers, organized by method type (⚡️ = vL
|
||||
- [`CPOTrainer`]
|
||||
- [`KTOTrainer`]
|
||||
|
||||
**Knowledge distillation**
|
||||
### Knowledge distillation
|
||||
|
||||
- [`GKDTrainer`]
|
||||
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
|
||||
## 🎉 What's New
|
||||
|
||||
**✨ OpenAI GPT OSS Support**: TRL now fully supports fine-tuning the latest [OpenAI GPT OSS models](https://huggingface.co/collections/openai/gpt-oss-68911959590a1634ba11c7a4)! Check out the:
|
||||
|
@ -1,13 +1,15 @@
|
||||
# Installation
|
||||
|
||||
You can install TRL either from PyPI or from source:
|
||||
|
||||
## PyPI
|
||||
|
||||
Install the library with pip or [uv](https://docs.astral.sh/uv/):
|
||||
|
||||
<hfoptions id="install">
|
||||
<hfoption id="uv">
|
||||
|
||||
uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions).
|
||||
uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions.
|
||||
|
||||
```bash
|
||||
uv pip install trl
|
||||
@ -24,6 +26,7 @@ pip install trl
|
||||
</hfoptions>
|
||||
|
||||
## Source
|
||||
|
||||
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
|
||||
|
||||
```bash
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Training with Jobs
|
||||
|
||||
[](https://huggingface.co/models?other=hf_jobs,trl)
|
||||
[](https://huggingface.co/models?other=hf_jobs,trl)
|
||||
|
||||
[Hugging Face Jobs](https://huggingface.co/docs/huggingface_hub/guides/jobs) lets you run training scripts on fully managed infrastructure—no need to manage GPUs or local environment setup.
|
||||
|
||||
|
@ -46,7 +46,6 @@ trl sft ... --attn_implementation kernels-community/flash-attn
|
||||
> [!TIP]
|
||||
> Now you can leverage faster attention backends with a pre-optimized kernel for your hardware configuration from the Hub, speeding up both development and training.
|
||||
|
||||
|
||||
## Comparing Attention Implementations
|
||||
|
||||
We evaluated various attention implementations available in transformers, along with different kernel backends, using **TRL** and **SFT**.
|
||||
@ -54,15 +53,14 @@ The experiments were run on a single **H100 GPU** with **CUDA 12.9**, leveraging
|
||||
Keep in mind that the results shown here are specific to this setup and may vary with different training configurations.
|
||||
|
||||
The following figure illustrates both **latency** (time per training step) and **peak allocated memory** for the different attention implementations and kernel backends.
|
||||
Kernel-based implementations perform on par with custom-installed attention, and increasing the model’s `max_length` further enhances performance. Memory consumption is similar across all implementations, showing no significant differences. We get the same performance but with less friction, as described in [the following section](#benchmarking-flash-attention-build-from-source-vs-hub-kernels).
|
||||
|
||||
Kernel-based implementations perform on par with custom-installed attention, and increasing the model’s `max_length` further enhances performance. Memory consumption is similar across all implementations, showing no significant differences. We get the same performance but with less friction, as described in [the following section](#flash-attention-vs-hub-kernels).
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kernels_guide_latency.png" alt="Latency and Memory Usage" width="45%"/>
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kernels_guide_peak_allocated_memory.png" alt="Latency and Memory Usage" width="45%"/>
|
||||
</div>
|
||||
|
||||
## Flash Attention (Build-from-Source) vs. Hub Kernels
|
||||
## Flash Attention vs. Hub Kernels
|
||||
|
||||
Building Flash Attention from source can be time-consuming, often taking anywhere from several minutes to hours, depending on your hardware, CUDA/PyTorch configuration, and whether precompiled wheels are available.
|
||||
|
||||
@ -74,7 +72,6 @@ You can combine **FlashAttention kernels** with **Liger kernels** for additional
|
||||
|
||||
First, install the Liger kernel dependency:
|
||||
|
||||
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
```
|
||||
@ -96,6 +93,4 @@ training_args = SFTConfig(
|
||||
)
|
||||
```
|
||||
|
||||
Learn more about this integration [here](./liger_kernel_integration).
|
||||
|
||||
|
||||
Learn more about the [Liger Kernel Integration](./liger_kernel_integration).
|
||||
|
@ -1,12 +1,11 @@
|
||||
# KTO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=kto,trl)
|
||||
[](https://huggingface.co/models?other=kto,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Kahneman-Tversky Optimization (KTO) was introduced in [KTO: Model Alignment as Prospect Theoretic Optimization](https://huggingface.co/papers/2402.01306) by [Kawin Ethayarajh](https://huggingface.co/kawine), [Winnie Xu](https://huggingface.co/xwinxu), [Niklas Muennighoff](https://huggingface.co/Muennighoff), Dan Jurafsky, [Douwe Kiela](https://huggingface.co/douwekiela).
|
||||
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Kahneman & Tversky's prospect theory tells us that humans perceive random variables in a biased but well-defined manner; for example, humans are famously loss-averse. We show that objectives for aligning LLMs with human feedback implicitly incorporate many of these biases -- the success of these objectives (e.g., DPO) over cross-entropy minimization can partly be ascribed to them being human-aware loss functions (HALOs). However, the utility functions these methods attribute to humans still differ from those in the prospect theory literature. Using a Kahneman-Tversky model of human utility, we propose a HALO that directly maximizes the utility of generations instead of maximizing the log-likelihood of preferences, as current methods do. We call this approach Kahneman-Tversky Optimization (KTO), and it matches or exceeds the performance of preference-based methods at scales from 1B to 30B. Crucially, KTO does not need preferences -- only a binary signal of whether an output is desirable or undesirable for a given input. This makes it far easier to use in the real world, where preference data is scarce and expensive.
|
||||
@ -51,7 +50,7 @@ accelerate launch train_kto.py
|
||||
|
||||
Distributed across 8 x H100 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
@ -60,14 +59,14 @@ To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-KTO) pe
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-KTO>:</span></strong>
|
||||
The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:
|
||||
The best programming language can vary depending on individual preferences, industry-specific requirements, technical skills, and familiarity with the specific use case or task. Here are some widely-used programming languages that have been noted as popular and widely used:
|
||||
|
||||
Here are some other factors to consider when choosing a programming language for a project:
|
||||
|
||||
<strong><span style="color: green;">1</span> JavaScript</strong>: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.
|
||||
<strong><span style="color: green;">2</span> Java</strong>: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.
|
||||
<strong><span style="color: green;">3</span> C++</strong>: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.
|
||||
<strong><span style="color: green;">4</span> Python</strong>: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.
|
||||
<strong><span style="color: green;">1</span> JavaScript</strong>: JavaScript is at the heart of the web and can be used for building web applications, APIs, and interactive front-end applications like frameworks like React and Angular. It's similar to C, C++, and F# in syntax structure and is accessible and easy to learn, making it a popular choice for beginners and professionals alike.
|
||||
<strong><span style="color: green;">2</span> Java</strong>: Known for its object-oriented programming (OOP) and support for Java 8 and .NET, Java is used for developing enterprise-level software applications, high-performance games, as well as mobile apps, game development, and desktop applications.
|
||||
<strong><span style="color: green;">3</span> C++</strong>: Known for its flexibility and scalability, C++ offers comprehensive object-oriented programming and is a popular choice for high-performance computing and other technical fields. It's a powerful platform for building real-world applications and games at scale.
|
||||
<strong><span style="color: green;">4</span> Python</strong>: Developed by Guido van Rossum in 1991, Python is a high-level, interpreted, and dynamically typed language known for its simplicity, readability, and versatility.
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset format
|
||||
@ -102,7 +101,6 @@ To ensure that we train MOEs similarly during preference-tuning, it is beneficia
|
||||
This option is enabled by setting `output_router_logits=True` in the model config (e.g. [`~transformers.MixtralConfig`]).
|
||||
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: `0.001`) in the model config.
|
||||
|
||||
|
||||
### Batch size recommendations
|
||||
|
||||
Use a per-step batch size that is at least 4, and an effective batch size between 16 and 128. Even if your effective batch size is large, if your per-step batch size is poor, then the KL estimate in KTO will be poor.
|
||||
|
@ -7,15 +7,15 @@
|
||||
|
||||
With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance.
|
||||
|
||||
| Speed Up | Memory Reduction |
|
||||
|--------------------------|-------------------------|
|
||||
| Speed Up | Memory Reduction |
|
||||
| --- | --- |
|
||||
|  |  |
|
||||
|
||||
1. To use Liger-Kernel in [`SFTTrainer`], first install it by:
|
||||
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
```
|
||||
|
||||
```bash
|
||||
pip install liger-kernel
|
||||
```
|
||||
|
||||
2. Once installed, set `use_liger_kernel` in [`SFTConfig`]. No other changes are needed!
|
||||
|
||||
|
@ -44,6 +44,7 @@ Here's a brief explanation for the logged metrics provided in the data:
|
||||
* `episode`: The current episode count in the training process.
|
||||
|
||||
### Crucial values
|
||||
|
||||
During training, many values are logged, here are the most important ones:
|
||||
|
||||
1. `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
@ -63,7 +64,7 @@ Here's a brief explanation for the logged metrics provided in the data for the G
|
||||
|
||||
* `num_tokens`: Total number of input tokens processed during training so far.
|
||||
|
||||
#### Completions
|
||||
### Completions
|
||||
|
||||
* `completions/mean_length`: Mean length of all generated completions (including those not ending with an EOS token).
|
||||
* `completions/min_length`: Minimum length among all generated completions.
|
||||
@ -73,34 +74,33 @@ Here's a brief explanation for the logged metrics provided in the data for the G
|
||||
* `completions/min_terminated_length`: Minimum length among completions that ended with an EOS token.
|
||||
* `completions/max_terminated_length`: Maximum length among completions that ended with an EOS token.
|
||||
|
||||
#### Rewards
|
||||
### Rewards
|
||||
|
||||
* `rewards/{reward_func_name}/mean`: The mean reward obtained from a specific, named reward function (e.g., `rewards/my_custom_reward/mean`). This is logged for each reward function used.
|
||||
* `rewards/{reward_func_name}/std`: The standard deviation of rewards from a specific, named reward function.
|
||||
* `reward`: The overall mean of the (potentially weighted and, if `args.scale_rewards` is true, normalized) rewards, after group-wise normalization (advantages).
|
||||
* `reward_std`: The standard deviation of the (potentially weighted) rewards *before* group-wise normalization for advantages.
|
||||
|
||||
#### Policy and Loss Metrics
|
||||
### Policy and Loss Metrics
|
||||
|
||||
* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in `GRPOConfig`) is non-zero.
|
||||
* `entropy`: Average entropy of token predictions across generated completions.
|
||||
* If Liger GRPOLoss is used (`use_liger_loss: True` in `GRPOConfig`):
|
||||
* `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds.
|
||||
* `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds.
|
||||
* If standard GRPOLoss is used (`use_liger_loss: False`):
|
||||
* `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound).
|
||||
* `clip_ratio/low_min`: The minimum observed fraction for `clip_ratio/low_mean` across batches/processes.
|
||||
* `clip_ratio/high_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the upper bound `1 + epsilon_high` (occurs when advantage is positive and ratio is above the bound).
|
||||
* `clip_ratio/high_max`: The maximum observed fraction for `clip_ratio/high_mean` across batches/processes.
|
||||
* `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound.
|
||||
* `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound).
|
||||
* `clip_ratio/low_min`: The minimum observed fraction for `clip_ratio/low_mean` across batches/processes.
|
||||
* `clip_ratio/high_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the upper bound `1 + epsilon_high` (occurs when advantage is positive and ratio is above the bound).
|
||||
* `clip_ratio/high_max`: The maximum observed fraction for `clip_ratio/high_mean` across batches/processes.
|
||||
* `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound.
|
||||
|
||||
### Crucial GRPO values
|
||||
|
||||
During GRPO training, monitor these values for insights into performance and stability:
|
||||
|
||||
1. `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training.
|
||||
1. `kl`: If `beta > 0`, this tracks the divergence from the reference model. Keep an eye on it to ensure the policy doesn't stray too far, which can lead to instability.
|
||||
1. `clip_ratio/*` (either `clip_ratio` for Liger loss or the more detailed `clip_ratio/...` metrics for standard loss): These indicate how often the policy updates are being constrained by the GRPO clipping mechanism. Very high values might suggest that the policy is trying to change too drastically (potentially due to large advantages or a learning rate that's too high) or that the epsilon clipping range is too restrictive.
|
||||
1. `completions/clipped_ratio`: A high ratio here indicates that the model is frequently generating completions that are cut off by `max_completion_length` rather than naturally ending with an EOS token. This might suggest issues with learning sequence termination or that `max_completion_length` is too short.
|
||||
1. `rewards/{reward_func_name}/mean`: Monitoring the mean of individual reward functions can help diagnose which aspects of the desired behavior the model is learning or struggling with, especially when using multiple reward sources.
|
||||
1. `entropy`: Measures how uncertain the policy is in its action choices, higher entropy suggests more exploration. A collapse in entropy means the policy is becoming overconfident and deterministic, often too early. This can stall learning by reducing exploration and making updates overly biased. Stable but non-zero entropy is usually a sign that the policy retains flexibility and continues to explore.
|
||||
|
||||
* `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training.
|
||||
* `kl`: If `beta > 0`, this tracks the divergence from the reference model. Keep an eye on it to ensure the policy doesn't stray too far, which can lead to instability.
|
||||
* `clip_ratio/*` (either `clip_ratio` for Liger loss or the more detailed `clip_ratio/...` metrics for standard loss): These indicate how often the policy updates are being constrained by the GRPO clipping mechanism. Very high values might suggest that the policy is trying to change too drastically (potentially due to large advantages or a learning rate that's too high) or that the epsilon clipping range is too restrictive.
|
||||
* `completions/clipped_ratio`: A high ratio here indicates that the model is frequently generating completions that are cut off by `max_completion_length` rather than naturally ending with an EOS token. This might suggest issues with learning sequence termination or that `max_completion_length` is too short.
|
||||
* `rewards/{reward_func_name}/mean`: Monitoring the mean of individual reward functions can help diagnose which aspects of the desired behavior the model is learning or struggling with, especially when using multiple reward sources.
|
||||
* `entropy`: Measures how uncertain the policy is in its action choices, higher entropy suggests more exploration. A collapse in entropy means the policy is becoming overconfident and deterministic, often too early. This can stall learning by reducing exploration and making updates overly biased. Stable but non-zero entropy is usually a sign that the policy retains flexibility and continues to explore.
|
||||
|
@ -22,14 +22,13 @@ Let's implement and train LoRA adapters in TRL scripts based on the core finding
|
||||
The blog post performs SFT on a range of models and datasets from the Hub, which we can reproduce in TRL.
|
||||
|
||||
| Model | Dataset |
|
||||
|-------|---------|
|
||||
| --- | --- |
|
||||
| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) |
|
||||
| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) |
|
||||
| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) |
|
||||
| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) |
|
||||
|
||||
<hfoptions id="sft">
|
||||
|
||||
<hfoption id="python">
|
||||
|
||||
We can integrate these findings with the TRL Python API like so:
|
||||
@ -64,7 +63,6 @@ trainer.train()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="jobs">
|
||||
|
||||
```bash
|
||||
@ -127,10 +125,10 @@ Once training starts, you can monitor the progress in [Trackio](https://huggingf
|
||||
|
||||
### Reinforcement Learning (GRPO)
|
||||
|
||||
The blog post performs GRPO on a range of models and datasets from the Hub, and once again we can reproduce the results in TRL.
|
||||
The blog post performs GRPO on a range of models and datasets from the Hub, and once again we can reproduce the results in TRL.
|
||||
|
||||
| Model | Dataset |
|
||||
|-------|---------|
|
||||
| --- | --- |
|
||||
| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [GSM8k](https://huggingface.co/datasets/openai/gsm8k) |
|
||||
| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) |
|
||||
| [Qwen3-8b-base](https://huggingface.co/Qwen/Qwen3-8b-base) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) |
|
||||
@ -226,7 +224,6 @@ def strip_reasoning_accuracy_reward(
|
||||
</details>
|
||||
|
||||
<hfoptions id="grpo">
|
||||
|
||||
<hfoption id="python">
|
||||
|
||||
We can implement these recommendations with the TRL Python API like so:
|
||||
@ -276,7 +273,6 @@ trainer.train()
|
||||
> This snippet skips the reward function which is defined above to keep the example concise.
|
||||
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="jobs">
|
||||
|
||||
```bash
|
||||
@ -321,7 +317,6 @@ To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub
|
||||
<hfoption id="local">
|
||||
|
||||
```bash
|
||||
|
||||
uv run "https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \
|
||||
--model_name_or_path Qwen/Qwen3-0.6B \
|
||||
--dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \
|
||||
@ -372,23 +367,23 @@ And most importantly, the LoRA model uses significantly less memory than the ful
|
||||
|
||||
Here are the parameters we used to train the above models
|
||||
|
||||
| Parameter | LoRA | Full FT |
|
||||
|----------------------------------|----------------------------------------------------|-------------------------------|
|
||||
| `--model_name_or_path` | HuggingFaceTB/SmolLM3-3B | HuggingFaceTB/SmolLM3-3B |
|
||||
| `--dataset_name` | HuggingFaceH4/OpenR1-Math-220k-default-verified | HuggingFaceH4/OpenR1-Math-220k-default-verified |
|
||||
| `--learning_rate` | 1.0e-5 | 1.0e-6 |
|
||||
| `--max_prompt_length` | 1024 | 1024 |
|
||||
| `--max_completion_length` | 4096 | 4096 |
|
||||
| `--lora_r` | 1 | - |
|
||||
| `--lora_alpha` | 32 | - |
|
||||
| `--lora_dropout` | 0.0 | - |
|
||||
| `--lora_target_modules` | all-linear | - |
|
||||
| Parameter | LoRA | Full FT |
|
||||
| --- | --- | --- |
|
||||
| `--model_name_or_path` | HuggingFaceTB/SmolLM3-3B | HuggingFaceTB/SmolLM3-3B |
|
||||
| `--dataset_name` | HuggingFaceH4/OpenR1-Math-220k-default-verified | HuggingFaceH4/OpenR1-Math-220k-default-verified |
|
||||
| `--learning_rate` | 1.0e-5 | 1.0e-6 |
|
||||
| `--max_prompt_length` | 1024 | 1024 |
|
||||
| `--max_completion_length` | 4096 | 4096 |
|
||||
| `--lora_r` | 1 | - |
|
||||
| `--lora_alpha` | 32 | - |
|
||||
| `--lora_dropout` | 0.0 | - |
|
||||
| `--lora_target_modules` | all-linear | - |
|
||||
|
||||
Let's break down the key findings of the blog post and how we were able to reproduce them.
|
||||
|
||||
### 1. *LoRA performs better when applied to all weight matrices*
|
||||
|
||||
The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction.
|
||||
The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction.
|
||||
|
||||

|
||||
|
||||
@ -402,7 +397,7 @@ peft_config = LoraConfig(target_modules="all-linear")
|
||||
|
||||
### 2. *The adapter needs sufficient capacity to learn from the dataset*
|
||||
|
||||
The blog post recommends using a sufficient LoRA rank to learn from the dataset. The rank determines the number of trainable parameters in the LoRA adapter. Therefore, "For datasets that exceed LoRA capacity, LoRA underperforms FullFT".
|
||||
The blog post recommends using a sufficient LoRA rank to learn from the dataset. The rank determines the number of trainable parameters in the LoRA adapter. Therefore, "For datasets that exceed LoRA capacity, LoRA underperforms FullFT".
|
||||
|
||||

|
||||
|
||||
@ -413,7 +408,7 @@ Reinforcement learning tasks typically require lower capacity, so smaller LoRA r
|
||||
The blog post defines the ideal dataset size for LoRA to match full fine-tuning as "Post-training scale". Which we can use to determine the recommended rank for SFT and RL LoRAs as:
|
||||
|
||||
| Task Type | Dataset Size | Recommended Rank |
|
||||
|-----------|-------------|------------------|
|
||||
| --- | --- | --- |
|
||||
| **SFT** | Post-training scale | 256 |
|
||||
| **RL** | Any size | 1-32 |
|
||||
|
||||
|
@ -8,7 +8,6 @@ With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder mode
|
||||
|
||||
## AutoModelForCausalLMWithValueHead
|
||||
|
||||
|
||||
[[autodoc]] AutoModelForCausalLMWithValueHead
|
||||
- __init__
|
||||
- forward
|
||||
@ -25,4 +24,4 @@ With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder mode
|
||||
|
||||
## create_reference_model
|
||||
|
||||
[[autodoc]] create_reference_model
|
||||
[[autodoc]] create_reference_model
|
||||
|
@ -14,11 +14,11 @@ You need to address this approach in three stages that we summarize as follows:
|
||||
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py)
|
||||
3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL")
|
||||
|
||||
Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3.
|
||||
Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3.
|
||||
|
||||
## Quickstart
|
||||
|
||||
Let us assume you have trained your reward adapter on `llama-7b` model using `RewardTrainer` and pushed the weights on the hub under `trl-lib/llama-7b-hh-rm-adapter`.
|
||||
Let us assume you have trained your reward adapter on `llama-7b` model using `RewardTrainer` and pushed the weights on the hub under `trl-lib/llama-7b-hh-rm-adapter`.
|
||||
When doing PPO, before passing the model to `PPOTrainer` create your model as follows:
|
||||
|
||||
```python
|
||||
@ -48,6 +48,7 @@ trainer = PPOTrainer(
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`.
|
||||
|
||||
```python
|
||||
@ -56,9 +57,9 @@ rewards = trainer.model.compute_reward_score(**inputs)
|
||||
|
||||
## Advanced usage
|
||||
|
||||
### Control on the adapter name
|
||||
### Control on the adapter name
|
||||
|
||||
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
|
||||
If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies.
|
||||
In this case, you want to be able to control the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`.
|
||||
|
||||
```python
|
||||
@ -71,6 +72,7 @@ rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_
|
||||
|
||||
For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32).
|
||||
Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`):
|
||||
|
||||
```python
|
||||
model_name = "llama-7b"
|
||||
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
|
||||
|
@ -1,16 +1,16 @@
|
||||
# Nash-MD Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=nash-md,trl)
|
||||
[](https://huggingface.co/models?other=nash-md,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi.
|
||||
Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences.
|
||||
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec).
|
||||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec).
|
||||
|
||||
## Quick start
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
# Online DPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=online-dpo,trl)
|
||||
[](https://huggingface.co/models?other=online-dpo,trl)
|
||||
|
||||
## Overview
|
||||
## Overview
|
||||
|
||||
Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel.
|
||||
Online DPO was proposed in [Direct Language Model Alignment from Online AI Feedback](https://huggingface.co/papers/2402.04792) by Shangmin Guo, Biao Zhang, Tianlin Liu, Tianqi Liu, Misha Khalman, Felipe Llinares, Alexandre Rame, Thomas Mesnard, Yao Zhao, Bilal Piot, Johan Ferret, and Mathieu Blondel.
|
||||
|
||||
The abstract from the paper is the following:
|
||||
|
||||
@ -112,7 +112,6 @@ This callback logs the model's generated completions directly to Weights & Biase
|
||||
|
||||

|
||||
|
||||
|
||||
## Example script
|
||||
|
||||
We provide an example script to train a model using the online DPO method. The script is available in [`examples/scripts/dpo_online.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo_online.py)
|
||||
@ -153,8 +152,7 @@ While training and evaluating, we record the following reward metrics. Here is a
|
||||
|
||||
To validate the online DPO implementation works, we ran experiments with the Pythia 1B, 2.8B, and 6.9B models on a single node of 8 x H100s. Here are the commands we used to run the experiments. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
|
||||
```
|
||||
```shell
|
||||
# 1B Online DPO experiment
|
||||
accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml \
|
||||
examples/scripts/dpo_online.py \
|
||||
@ -213,9 +211,8 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
|
||||
|
||||
Checkpoints and experiment tracking are available at:
|
||||
|
||||
- [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
|
||||
- [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
|
||||
|
||||
* [🤗 Model checkpoints](https://huggingface.co/collections/trl-lib/online-dpo-66acd3fa38a331a9cd457b07)
|
||||
* [🐝 Tracked experiment](https://wandb.ai/huggingface/trl/reports/Online-DPO-experiments-for-TL-DR-summarisation--Vmlldzo5MTczMDU0)
|
||||
|
||||
To evaluate, we use [vLLM](https://github.com/vllm-project/vllm) to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR.
|
||||
For more information on how to use judges, see [Judges](judges).
|
||||
|
@ -1,6 +1,6 @@
|
||||
# ORPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=orpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
[](https://huggingface.co/models?other=orpo,trl) [](https://github.com/huggingface/smol-course/tree/main/2_preference_alignment)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -54,7 +54,7 @@ accelerate launch train_orpo.py
|
||||
|
||||
Distributed across 8 GPUs, the training takes approximately 30 minutes. You can verify the training progress by checking the reward graph. An increasing trend in the reward margin indicates that the model is improving and generating better responses over time.
|
||||
|
||||

|
||||

|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-ORPO) performs, you can use the [Transformers Chat CLI](https://huggingface.co/docs/transformers/quicktour#chat-with-text-generation-models).
|
||||
|
||||
@ -64,11 +64,11 @@ What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-ORPO>:</span></strong>
|
||||
It's challenging to determine the best programming language as no one language is perfect, as the complexity of a task and the type of project are significant factors. Some popular languages include Java, Python, JavaScript, and
|
||||
C++. If you have specific needs or requirements for a specific project, it's important to choose the language that best suits those needs.
|
||||
C++. If you have specific needs or requirements for a specific project, it's important to choose the language that best suits those needs.
|
||||
|
||||
Here are some other factors to consider when choosing a programming language for a project:
|
||||
|
||||
<strong><span style="color: green;">• Language proficiency:</span></strong> A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.
|
||||
<strong><span style="color: green;">• Language proficiency:</span></strong> A good programming language is more likely to be easy to understand and use, and will allow developers to collaborate on projects more efficiently.
|
||||
<strong><span style="color: green;">• Ease of use:</span></strong> There are tools and libraries available to make programming more accessible, so developers should choose a language that can help them get started easier.
|
||||
<strong><span style="color: green;">• Code readability:</span></strong> A clear and concise codebase should be easy to read and understand, especially when working with large projects.
|
||||
<strong><span style="color: green;">• Tool and framework support:</span></strong> There are numerous libraries available for Python, Java, and JavaScript, along with tools like IDEs and static code analysis tools.
|
||||
@ -118,7 +118,7 @@ While training and evaluating, we record the following reward metrics:
|
||||
- `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
|
||||
- `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`
|
||||
- `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
|
||||
|
||||
|
||||
## ORPOTrainer
|
||||
|
||||
[[autodoc]] ORPOTrainer
|
||||
|
@ -170,7 +170,7 @@ $$
|
||||
}
|
||||
$$
|
||||
|
||||
Despite \\( \textcolor{red}{\pi_{\text{inference}}} \\) and \\( \textcolor{blue}{\pi_{\text{training}}} \\) sharing the same model parameters \\( \theta \\), they can produce significantly different token probabilities. This unexpected behavior implicitly breaks the on-policy assumption, and silently turns training off-policy.
|
||||
Despite \\( \textcolor{red}{\pi_{\text{inference}}} \\) and \\( \textcolor{blue}{\pi_{\text{training}}} \\) sharing the same model parameters \\( \theta \\), they can produce significantly different token probabilities. This unexpected behavior implicitly breaks the on-policy assumption, and silently turns training off-policy.
|
||||
|
||||
Truncated Importance Sampling (TIS) addresses this issue by adapting the model update via importance-sampling correction. The gradient computation of the aforementioned PPO objective becomes
|
||||
|
||||
@ -458,10 +458,7 @@ trainer = SFTTrainer(
|
||||
Dynamic Fine-Tuning (DFT) improves the generalization of Large Language Models (LLMs) by dynamically rescaling gradients, outperforming standard Supervised Fine-Tuning (SFT) and showing competitive results in offline reinforcement learning.
|
||||
|
||||
$$
|
||||
\mathcal{L}_{\text{DFT}}(\theta)
|
||||
= \mathbb{E}_{(x,y) \sim \mathcal{D}} \left[ - \sum_{t=1}^{|y|}
|
||||
\textcolor{red}{\text{sg}\big(\pi_\theta(y_t \mid y_{<t}, x)\big)}
|
||||
\; \log \pi_\theta(y_t \mid y_{<t}, x) \right]
|
||||
\mathcal{L}_{\text{DFT}}(\theta) = \mathbb{E}_{(x,y) \sim \mathcal{D}} \left[ - \sum_{t=1}^{|y|} \textcolor{red}{\text{sg}\big(\pi_\theta(y_t \mid y_{<t}, x)\big)} \; \log \pi_\theta(y_t \mid y_{<t}, x) \right]
|
||||
$$
|
||||
|
||||
where \\( \text{sg}(\cdot) \\) is the stop-gradient operator. To use DFT with SFT as described in the paper, you can use the `loss_type="dft"` argument:
|
||||
|
@ -3,17 +3,10 @@
|
||||
The notebooks and scripts in these examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory efficient manner. Most of PEFT methods supported in peft library but note that some PEFT methods such as Prompt tuning are not supported.
|
||||
For more information on LoRA, see the [original paper](https://huggingface.co/papers/2106.09685).
|
||||
|
||||
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
|
||||
|
||||
| File | Task | Description | Colab link |
|
||||
|---|---| --- |
|
||||
| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
|
||||
| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
|
||||
| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
|
||||
|
||||
## Installation
|
||||
|
||||
Note: peft is in active development, so we install directly from their Github page.
|
||||
Peft also relies on the latest version of transformers.
|
||||
Peft also relies on the latest version of transformers.
|
||||
|
||||
```bash
|
||||
pip install trl[peft]
|
||||
@ -27,7 +20,7 @@ Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scr
|
||||
|
||||
## How to use it?
|
||||
|
||||
Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.
|
||||
Simply declare a `PeftConfig` object in your script and pass it through `.from_pretrained` to load the TRL+PEFT model.
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
@ -47,7 +40,9 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
peft_config=lora_config,
|
||||
)
|
||||
```
|
||||
|
||||
And if you want to load your model in 8bit precision:
|
||||
|
||||
```python
|
||||
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
@ -55,7 +50,9 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
peft_config=lora_config,
|
||||
)
|
||||
```
|
||||
|
||||
... or in 4bit precision:
|
||||
|
||||
```python
|
||||
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
@ -64,7 +61,6 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
)
|
||||
```
|
||||
|
||||
|
||||
## Launch scripts
|
||||
|
||||
The `trl` library is powered by `accelerate`. As such it is best to configure and launch trainings with the following commands:
|
||||
@ -77,6 +73,7 @@ accelerate launch examples/scripts/ppo.py --use_peft # launch`es training
|
||||
## Using `trl` + `peft` and Data Parallelism
|
||||
|
||||
You can scale up to as many GPUs as you want, as long as you are able to fit the training process in a single device. The only tweak you need to apply is to load the model as follows:
|
||||
|
||||
```python
|
||||
from peft import LoraConfig
|
||||
...
|
||||
@ -94,7 +91,9 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
peft_config=lora_config,
|
||||
)
|
||||
```
|
||||
|
||||
And if you want to load your model in 8bit precision:
|
||||
|
||||
```python
|
||||
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
@ -102,7 +101,9 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
load_in_8bit=True,
|
||||
)
|
||||
```
|
||||
|
||||
... or in 4bit precision:
|
||||
|
||||
```python
|
||||
pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
@ -110,21 +111,20 @@ pretrained_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
load_in_4bit=True,
|
||||
)
|
||||
```
|
||||
|
||||
Finally, make sure that the rewards are computed on correct device as well, for that you can use `ppo_trainer.model.current_device`.
|
||||
|
||||
## Naive pipeline parallelism (NPP) for large models (>60B models)
|
||||
|
||||
The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B models). This is a simple way to parallelize the model across multiple GPUs.
|
||||
The `trl` library also supports naive pipeline parallelism (NPP) for large models (>60B models). This is a simple way to parallelize the model across multiple GPUs.
|
||||
This paradigm, termed as "Naive Pipeline Parallelism" (NPP) is a simple way to parallelize the model across multiple GPUs. We load the model and the adapters across multiple GPUs and the activations and gradients will be naively communicated across the GPUs. This supports `int8` models as well as other `dtype` models.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl-npp.png">
|
||||
</div>
|
||||

|
||||
|
||||
### How to use NPP?
|
||||
|
||||
Simply load your model with a custom `device_map` argument on the `from_pretrained` to split your model across multiple devices. Check out this [nice tutorial](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) on how to properly create a `device_map` for your model.
|
||||
|
||||
Simply load your model with a custom `device_map` argument on the `from_pretrained` to split your model across multiple devices. Check out this [nice tutorial](https://github.com/huggingface/blog/blob/main/accelerate-large-models.md) on how to properly create a `device_map` for your model.
|
||||
|
||||
Also make sure to have the `lm_head` module on the first GPU device as it may throw an error if it is not on the first device. As this time of writing, you need to install the `main` branch of `accelerate`: `pip install git+https://github.com/huggingface/accelerate.git@main` and `peft`: `pip install git+https://github.com/huggingface/peft.git@main`.
|
||||
|
||||
### Launch scripts
|
||||
|
@ -1,10 +1,11 @@
|
||||
# PPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=ppo,trl)
|
||||
[](https://huggingface.co/models?other=ppo,trl)
|
||||
|
||||
TRL supports training LLMs with [Proximal Policy Optimization (PPO)](https://huggingface.co/papers/1707.06347).
|
||||
|
||||
References:
|
||||
|
||||
- [Fine-Tuning Language Models from Human Preferences](https://github.com/openai/lm-human-preferences)
|
||||
- [Learning to Summarize from Human Feedback](https://github.com/openai/summarize-from-feedback)
|
||||
- [The N Implementation Details of RLHF with PPO](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo)
|
||||
@ -31,49 +32,45 @@ python examples/scripts/ppo/ppo.py \
|
||||
--missing_eos_penalty 1.0
|
||||
```
|
||||
|
||||
|
||||
## Explanation of the logged metrics
|
||||
|
||||
The logged metrics are as follows. Here is an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35)
|
||||
|
||||
* `eps`: Tracks the number of episodes per second.
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
|
||||
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
* `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
|
||||
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
|
||||
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
|
||||
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
|
||||
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
|
||||
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
|
||||
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
|
||||
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
* `lr`: lr: The current learning rate used by the optimizer.
|
||||
* `episode`: episode: The current episode count in the training process.
|
||||
|
||||
- `eps`: Tracks the number of episodes per second.
|
||||
- `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
|
||||
- `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
|
||||
- `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
- `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
- `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
- `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
|
||||
- `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
|
||||
- `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
|
||||
- `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
|
||||
- `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
|
||||
- `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
|
||||
- `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
|
||||
- `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
- `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
- `lr`: lr: The current learning rate used by the optimizer.
|
||||
- `episode`: episode: The current episode count in the training process.
|
||||
|
||||
## Cookbook
|
||||
|
||||
* Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
|
||||
* Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it.
|
||||
* Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
|
||||
* Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
|
||||
* Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
|
||||
|
||||
- Debugging TIP: `objective/rlhf_reward`: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
|
||||
- Debugging TIP: `val/ratio`: this number should float around 1.0, and it gets clipped by `--cliprange 0.2` with PPO's surrogate loss. So if this `ratio` is too high like 2.0 or 1000.0 or too small like 0.1, it means the updates between consecutive policies are too drastic. You should try understand why this is happening and try to fix it.
|
||||
- Memory TIP: If you are running out of memory, you can try to reduce the `--per_device_train_batch_size` or increase the `--gradient_accumulation_steps` to reduce the memory footprint.
|
||||
- Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint `accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml`.
|
||||
- Usage TIP: We recommend to use the "EOS trick" via `--missing_eos_penalty`, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
|
||||
|
||||
## What is my model doing exactly?
|
||||
|
||||
To help you understand what your model is doing, we periodically log some sample completions from the model. Here is an example of a completion. In an example [tracked run at Weights and Biases](https://wandb.ai/huggingface/trl/runs/dd2o3g35), it looks like the following, allowing you to see the model's response at different stages of training. By default we generate `--num_sample_generations 10` during training, but you can customize the number of generations.
|
||||
|
||||

|
||||

|
||||
|
||||
In the logs the sampled generations look like
|
||||
|
||||
In the logs the sampled generations look like
|
||||
|
||||
```
|
||||
```txt
|
||||
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
|
||||
┃ query ┃ model response ┃ score ┃
|
||||
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
|
||||
@ -177,7 +174,7 @@ This PPO implementation is based on the [The N+ Implementation Details of RLHF w
|
||||
|
||||
To validate the PPO implementation works, we ran experiment on the 1B model. Here are the command we used to run the experiment. We take the SFT / RM models directly from [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
```
|
||||
```shell
|
||||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
|
||||
examples/scripts/ppo/ppo_tldr.py \
|
||||
--output_dir models/minimal/ppo_tldr \
|
||||
@ -212,8 +209,7 @@ The PPO checkpoint gets a 64.7% preferred rate vs the 33.0% preference rate of t
|
||||
|
||||
Metrics:
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
```bash
|
||||
# pip install openrlbenchmark==0.2.1a5
|
||||
|
@ -1,6 +1,6 @@
|
||||
# PRM Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=prm,trl)
|
||||
[](https://huggingface.co/models?other=prm,trl)
|
||||
|
||||
> [!WARNING]
|
||||
> PRM Trainer is an experimental API which is subject to change at any time.
|
||||
@ -15,7 +15,6 @@ The abstract from the paper is the following:
|
||||
|
||||
This post-training method was contributed by [Gaetan Lopez](https://github.com/gaetanlop), [Lewis Tunstall](https://huggingface.co/lewtun), [Quentin Gallouédec](https://huggingface.co/qgallouedec) and [Agustín Piqueres](https://huggingface.co/plaguss).
|
||||
|
||||
|
||||
## Quick start
|
||||
|
||||
This example demonstrates how to train a model using the PRM method. We use the [Qwen 0.5B model](https://huggingface.co/Qwen/Qwen2-0.5B) as the base model. We use the stepwise supervision data from the [Math Shepherd dataset](https://huggingface.co/datasets/trl-lib/math_shepherd). You can view the data in the dataset here:
|
||||
@ -54,7 +53,6 @@ Distributed across 8 GPUs, the training takes approximately 1 hour.
|
||||
|
||||
To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-Reward-Math-Sheperd) performs, you can use the following script.
|
||||
|
||||
|
||||
```python
|
||||
from datasets import load_dataset
|
||||
from transformers import pipeline
|
||||
|
@ -7,9 +7,7 @@
|
||||
|
||||
Sequence lengths in the dataset can vary widely. When data is batched, sequences are padded to match the longest one in the batch, which can cause high memory usage, even if most sequences are relatively short.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/why_you_should_truncate.png" alt="Truncation prompt-completion" width="600"/>
|
||||
</div>
|
||||

|
||||
|
||||
To reduce memory usage, it's important to truncate sequences to a reasonable length. While TRL trainers truncate sequences by default, you may want to adjust the default truncation length to better align with your specific use case.
|
||||
|
||||
@ -18,9 +16,7 @@ To reduce memory usage, it's important to truncate sequences to a reasonable len
|
||||
|
||||
DPO truncation is applied first to the prompt and to the completion via the `max_prompt_length` and `max_completion_length` parameters. The `max_length` parameter is then used to truncate the resulting sequence.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_prompt_completion.png" alt="DPO truncation" width="600"/>
|
||||
</div>
|
||||

|
||||
|
||||
To set the truncation parameters, use the following code snippet:
|
||||
|
||||
@ -43,9 +39,7 @@ training_args = DPOConfig(..., max_completion_length=...)
|
||||
|
||||
SFT truncation is applied to the input sequence via the `max_length` parameter.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/truncation_input_ids.png" alt="Truncation input ids" width="600"/>
|
||||
</div>
|
||||

|
||||
|
||||
To set the truncation parameter, use the following code snippet:
|
||||
|
||||
@ -71,16 +65,14 @@ To help you choose an appropriate value, we provide a utility to visualize the s
|
||||
> [!TIP]
|
||||
> This technique applies only to SFT.
|
||||
|
||||
|
||||
[Truncation](#truncation) has several drawbacks:
|
||||
|
||||
1. **Loss of information**: Key data at the end of a sequence may be discarded.
|
||||
2. **Choosing truncation length**: Too short loses data; too long undermines efficiency.
|
||||
|
||||
Packing, introduced in [Raffel et al., 2020](https://huggingface.co/papers/1910.10683), addresses these issues by grouping sequences instead of truncating. It concatenates and splits dataset sequences into the desired lengths.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/packing_2.png" alt="Packing" width="600"/>
|
||||
</div>
|
||||

|
||||
|
||||
Packing reduces padding by merging several sequences in one row when possible. We use an advanced method to be near-optimal in the way we pack the dataset. To enable packing, use `packing=True` in the [`SFTConfig`].
|
||||
|
||||
@ -142,9 +134,7 @@ training_args = KTOConfig(..., use_liger_loss=True)
|
||||
|
||||
Padding-free batching is an alternative approach for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/padding-free.png" alt="Padding-free batching" width="600"/>
|
||||
</div>
|
||||

|
||||
|
||||
> [!WARNING]
|
||||
> It's highly recommended to use padding-free batching with **FlashAttention 2** or **FlashAttention 3**. Otherwise, you may encounter batch contamination issues.
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Reward Modeling
|
||||
|
||||
[](https://huggingface.co/models?other=reward-trainer,trl)
|
||||
[](https://huggingface.co/models?other=reward-trainer,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# RLOO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=rloo,trl)
|
||||
[](https://huggingface.co/models?other=rloo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -101,14 +101,13 @@ where \\( \beta > 0 \\) controls the strength of the KL penalty.
|
||||
|
||||
### Computing the advantage
|
||||
|
||||
Once the rewards for each completion have been computed, we calculate a baseline as the average reward of all other samples in the same batch, excluding the current sample. This baseline is used to reduce the variance of the policy gradient estimate. The advantage for each completion is then obtained as the difference between its own reward and this leave-one-out baseline.
|
||||
Once the rewards for each completion have been computed, we calculate a baseline as the average reward of all other samples in the same batch, excluding the current sample. This baseline is used to reduce the variance of the policy gradient estimate. The advantage for each completion is then obtained as the difference between its own reward and this leave-one-out baseline.
|
||||
|
||||
Formally, for a batch of G completions, the baseline for completion is:
|
||||
$$
|
||||
b_i = \frac{1}{G-1} \sum_{j \neq i} r_j
|
||||
$$
|
||||
|
||||
|
||||
and then the advantage for each completion is computed as the difference between its reward and the baseline:
|
||||
|
||||
$$
|
||||
@ -151,9 +150,9 @@ While training and evaluating, we record the following reward metrics:
|
||||
- `entropy`: Average entropy of token predictions across generated completions. (If `mask_truncated_completions=True`, masked sequences tokens are excluded.)
|
||||
- `kl`: The average KL divergence between the model and the reference model, calculated over generated completions. Logged only if `beta` is nonzero.
|
||||
- `clip_ratio/region_mean`: The ratio of sequence probabilities where the RLOO objective is clipped to stay within the trust region:
|
||||
$$
|
||||
\text{clip}\left( r_{i}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i}(\theta) = \frac{\pi_\theta(o_{i} \mid q)}{\pi_{\theta_{\text{old}}}(o_{i} \mid q)}\,.
|
||||
$$
|
||||
$$
|
||||
\text{clip}\left( r_{i}(\theta), 1 - \epsilon_\mathrm{low}, 1 + \epsilon_\mathrm{high} \right)\,, \qquad r_{i}(\theta) = \frac{\pi_\theta(o_{i} \mid q)}{\pi_{\theta_{\text{old}}}(o_{i} \mid q)}\,.
|
||||
$$
|
||||
|
||||
A higher value means more samples are clipped, which constrains how much the policy $\pi_\theta$ can change.
|
||||
- `clip_ratio/low_mean`: The average ratio of sequence probabilities that were clipped on the lower bound of the trust region: \\(r_{i,t}(\theta) < 1 - \epsilon_\mathrm{low}\\)
|
||||
@ -166,6 +165,7 @@ $$
|
||||
### Speed up training with vLLM-powered generation
|
||||
|
||||
Generation is often the main bottleneck when training with online methods. To accelerate generation, you can use [vLLM](https://github.com/vllm-project/vllm), a high-throughput, low-latency inference engine for LLMs. To enable it, first install the package with
|
||||
|
||||
```shell
|
||||
pip install trl[vllm]
|
||||
```
|
||||
@ -177,11 +177,13 @@ We support two ways of using vLLM during training: **server mode** and **colocat
|
||||
In this mode, vLLM runs in a separate process (and using separate GPUs) and communicates with the trainer via HTTP. This is ideal if you have dedicated GPUs for inference.
|
||||
|
||||
1. **Start the vLLM server**:
|
||||
|
||||
```bash
|
||||
trl vllm-serve --model <model_name>
|
||||
```
|
||||
|
||||
2. **Enable server mode in your training script**:
|
||||
|
||||
```python
|
||||
from trl import RLOOConfig
|
||||
|
||||
@ -214,12 +216,7 @@ training_args = RLOOConfig(
|
||||
>
|
||||
> We provide a [HF Space](https://huggingface.co/spaces/trl-lib/recommend-vllm-memory) to help estimate the recommended GPU memory utilization based on your model configuration and experiment settings. Simply use it as follows to get `vllm_gpu_memory_utilization` recommendation:
|
||||
>
|
||||
> <iframe
|
||||
> src="https://trl-lib-recommend-vllm-memory.hf.space"
|
||||
> frameborder="0"
|
||||
> width="850"
|
||||
> height="450"
|
||||
> ></iframe>
|
||||
> <iframe src="https://trl-lib-recommend-vllm-memory.hf.space" frameborder="0" width="850" height="450"></iframe>
|
||||
>
|
||||
> If the recommended value does not work in your environment, we suggest adding a small buffer (e.g., +0.05 or +0.1) to the recommended value to ensure stability.
|
||||
>
|
||||
@ -418,6 +415,7 @@ You can test this function as follows:
|
||||
>>> reward_func(prompts=prompts, completions=completions, ground_truth=ground_truth)
|
||||
[1.0, 0.0]
|
||||
```
|
||||
|
||||
#### Example 4: Multi-task reward functions
|
||||
|
||||
Below is an example of using multiple reward functions in the [`RLOOTrainer`]. In this example, we define two task-specific reward functions: `math_reward_func` and `coding_reward_func`. The `math_reward_func` rewards math problems based on their correctness, while the `coding_reward_func` rewards coding problems based on whether the solution works.
|
||||
@ -478,8 +476,6 @@ In this example, the `math_reward_func` and `coding_reward_func` are designed to
|
||||
|
||||
Note that the [`RLOOTrainer`] will ignore the `None` rewards returned by the reward functions and only consider the rewards returned by the relevant functions. This ensures that the model is trained on the relevant tasks and ignores the tasks for which there is no relevant reward function.
|
||||
|
||||
|
||||
|
||||
#### Passing the reward function to the trainer
|
||||
|
||||
To use your custom reward function, pass it to the [`RLOOTrainer`] as follows:
|
||||
|
@ -4,15 +4,11 @@ The notebooks and scripts in these examples show how to fine-tune a model with a
|
||||
|
||||
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
|
||||
|
||||
|
||||
|
||||
| File | Description |
|
||||
|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
|
||||
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
|
||||
|
||||
|
||||
| File | Description |
|
||||
| --- |--- |
|
||||
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
|
||||
| [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. |
|
||||
| [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. |
|
||||
|
||||
## Usage
|
||||
|
||||
@ -30,7 +26,6 @@ python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_a
|
||||
|
||||
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
|
||||
|
||||
## Few notes on multi-GPU
|
||||
|
||||
## Few notes on multi-GPU
|
||||
|
||||
To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`.
|
||||
To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`.
|
||||
|
@ -106,7 +106,6 @@ $$
|
||||
where \\( y_t \\) is the target token at timestep \\( t \\), and the model is trained to predict the next token given the previous ones. In practice, padding tokens are masked out during loss computation.
|
||||
|
||||
> [!TIP]
|
||||
>
|
||||
> [On the Generalization of SFT: A Reinforcement Learning Perspective with Reward Rectification](https://huggingface.co/papers/2508.05629) proposes an alternative loss function, called **Dynamic Fine-Tuning (DFT)**, which aims to improve generalization by rectifying the reward signal. This method can be enabled by setting `loss_type="dft"` in the [`SFTConfig`]. For more details, see [Paper Index - Dynamic Fine-Tuning](paper_index#on-the-generalization-of-sft-a-reinforcement-learning-perspective-with-reward-rectification).
|
||||
|
||||
### Label shifting and masking
|
||||
|
@ -48,11 +48,13 @@ You can customize the server configuration by passing additional arguments. For
|
||||
> When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
|
||||
>
|
||||
> Set GPUs **0-3** for vLLM generation:
|
||||
>
|
||||
> ```sh
|
||||
> CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
|
||||
> ```
|
||||
>
|
||||
> And GPUs **4-7** for training:
|
||||
> And GPUs **4-7** for training:
|
||||
>
|
||||
> ```sh
|
||||
> CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
> ```
|
||||
@ -79,12 +81,14 @@ You can customize the server configuration by passing additional arguments. For
|
||||
> [!WARNING]
|
||||
> When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`.
|
||||
>
|
||||
> Set GPUs **0-3** for vLLM generation:
|
||||
> Set GPUs **0-3** for vLLM generation:
|
||||
>
|
||||
> ```sh
|
||||
> CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model <model_name>
|
||||
> ```
|
||||
>
|
||||
> And GPUs **4-7** for training:
|
||||
> And GPUs **4-7** for training:
|
||||
>
|
||||
> ```sh
|
||||
> CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py
|
||||
> ```
|
||||
|
@ -43,7 +43,6 @@ To use the data efficiently, we use a technique called packing: instead of havin
|
||||
With this approach the training is much more efficient as each token that is passed through the model is also trained in contrast to padding tokens which are usually masked from the loss.
|
||||
If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context you can also use a classical data loader.
|
||||
|
||||
|
||||
```python
|
||||
# load model in 8bit
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@ -109,6 +108,7 @@ peft_config = LoraConfig(
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
```
|
||||
|
||||
As detailed in the next section, the resulting adapter can be merged into the frozen model and saved for further downstream use.
|
||||
|
||||
## Reinforcement Learning from Human Feedback
|
||||
|
@ -7,14 +7,13 @@ This document will guide you through the process of using vLLM with TRL for fast
|
||||
|
||||
> [!TIP]
|
||||
> The following trainers currently support generation with vLLM:
|
||||
>
|
||||
>
|
||||
> - [`GRPOTrainer`]
|
||||
> - [`OnlineDPOTrainer`]
|
||||
> - [`NashMDTrainer`]
|
||||
> - [`XPOTrainer`]
|
||||
> - [`RLOOTrainer`]
|
||||
|
||||
|
||||
## 🚀 How can I use vLLM with TRL to speed up training?
|
||||
|
||||
💡 **Note**: Resources required for this specific example: a single node with 8 GPUs.
|
||||
@ -235,16 +234,16 @@ Separately, the number of completions to generate per prompt is controlled by th
|
||||
|
||||
### 🥸 More detail on what happens under the hood when running the server
|
||||
|
||||
* The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`.
|
||||
* Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [here](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035).
|
||||
* The client (trainer) then requests these completions from the server.
|
||||
* These completions are used to compute the reward signal.
|
||||
* Based on the reward signal and the model’s output, the loss is computed, and the backward pass is performed to update the model’s weights.
|
||||
* **Note**: The server only handles completion generation — it doesn’t train the model. Therefore, the model’s weights aren’t updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
|
||||
- The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`.
|
||||
- Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [these lines](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035).
|
||||
- The client (trainer) then requests these completions from the server.
|
||||
- These completions are used to compute the reward signal.
|
||||
- Based on the reward signal and the model’s output, the loss is computed, and the backward pass is performed to update the model’s weights.
|
||||
- **Note**: The server only handles completion generation — it doesn’t train the model. Therefore, the model’s weights aren’t updated on the server. Once the backward pass is complete, the client sends the updated weights to the server using `vllm_client.update_named_param(name, param.data)`.
|
||||
|
||||
When using vLLM, ensure the GPUs assigned for training and generation are separate to avoid NCCL communication conflicts. If you do not set the `CUDA_VISIBLE_DEVICES` environment variable, the training script will use all available GPUs by default, which may lead to device conflicts. Starting from TRL next release after v0.19.1, the code automatically detects and prevents same-device usage, raising a error at the vllm server process:
|
||||
|
||||
```
|
||||
```log
|
||||
RuntimeError: Attempting to use the same CUDA device for multiple distinct roles/ranks within the same communicator.
|
||||
Ensure that trainer is using different devices than vLLM server.
|
||||
```
|
||||
@ -307,23 +306,23 @@ options:
|
||||
|
||||
### 💆🏻♀️ What's the best distributed setup?
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
First and foremost, always remember that the optimal setup depends on:
|
||||
|
||||
* The model size
|
||||
* The number of GPUs you have
|
||||
* The GPU memory size
|
||||
* The batch size you are using
|
||||
* The number of requests you are sending to the server (prompts)
|
||||
* The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
|
||||
* The number of completions you are generating for each request (`num_generations`)
|
||||
- The model size
|
||||
- The number of GPUs you have
|
||||
- The GPU memory size
|
||||
- The batch size you are using
|
||||
- The number of requests you are sending to the server (prompts)
|
||||
- The `max_model_len` you are using (this is the max length of the input sequence that the model can process, a.k.a. the context window size)
|
||||
- The number of completions you are generating for each request (`num_generations`)
|
||||
|
||||
Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) using 8 H100 GPUs show that:
|
||||
|
||||
* For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results.
|
||||
* For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window.
|
||||
- For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results.
|
||||
- For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window.
|
||||
|
||||
### vLLM with Transformers Backend
|
||||
|
||||
@ -334,7 +333,7 @@ For more details, check out [vLLM Transformers Backend](https://blog.vllm.ai/202
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
```sh
|
||||
CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen
|
||||
2.5-VL-3B-Instruct --tensor-parallel-size 1 --port 8000 --enforce_eager --vllm_model_impl transformers
|
||||
```
|
||||
@ -496,7 +495,5 @@ training_args = RLOOConfig(
|
||||
> [!WARNING]
|
||||
> Check the documentation of the trainer you are using for specific details on vLLM usage and parameters.
|
||||
|
||||
|
||||
> [!WARNING]
|
||||
> To reduce GPU memory usage when running vLLM, consider [enabling vLLM sleep mode](reducing_memory_usage#vllm-sleep-mode).
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
# XPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=xpo,trl)
|
||||
[](https://huggingface.co/models?other=xpo,trl)
|
||||
|
||||
## Overview
|
||||
|
||||
@ -57,7 +57,7 @@ To see how the [trained model](https://huggingface.co/trl-lib/Qwen2-0.5B-XPO) pe
|
||||
What is the best programming language?
|
||||
|
||||
<strong><span style="color: blue;"><trl-lib/Qwen2-0.5B-XPO>:</span></strong>
|
||||
The best programming language depends on individual preferences and familiarity with coding concepts. Some popular languages include Python, Java, C++, and JavaScript.
|
||||
The best programming language depends on individual preferences and familiarity with coding concepts. Some popular languages include Python, Java, C++, and JavaScript.
|
||||
</code></pre>
|
||||
|
||||
## Expected dataset type
|
||||
@ -148,7 +148,6 @@ While training and evaluating we record the following reward metrics:
|
||||
* `alpha`: The weight of the XPO loss term. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
|
||||
* `beta`: The parameter that controls the weight of the loss term representing the deviation from the reference model. Typically fixed, but can be made dynamic by passing a list to [`XPOConfig`].
|
||||
|
||||
|
||||
## XPOTrainer
|
||||
|
||||
[[autodoc]] XPOTrainer
|
||||
|
@ -1,3 +1,3 @@
|
||||
# Examples
|
||||
|
||||
Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples.
|
||||
Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples.
|
||||
|
@ -2,6 +2,5 @@
|
||||
|
||||
This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications.
|
||||
|
||||
- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
|
||||
- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
|
||||
- [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook.
|
||||
|
@ -1,609 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "WQpNapZNWuXP"
|
||||
},
|
||||
"source": [
|
||||
"\n",
|
||||
"**Best-of-n sampling as an alternative to RLHF**\n",
|
||||
"\n",
|
||||
"This notebook compares reward-model scores of prompt based responses from \n",
|
||||
"1. a base model (`gpt2-imdb`)\n",
|
||||
"2. `RLHF` tuned model based on this base-model \n",
|
||||
"3. the base-model again from which we sample n responses to each prompt, score them and take the best scored one AKA the `best-of-n sampled` model\n",
|
||||
"\n",
|
||||
"Import dependencies"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "vDA6qayz692w"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install transformers trl"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "M1s_iNm773hM"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"from transformers import pipeline, AutoTokenizer\n",
|
||||
"from datasets import load_dataset\n",
|
||||
"\n",
|
||||
"from trl import AutoModelForCausalLMWithValueHead\n",
|
||||
"from trl.core import LengthSampler\n",
|
||||
"\n",
|
||||
"device = torch.accelerator.current_accelerator().type if hasattr(torch, \"accelerator\") else \"cuda\"\n",
|
||||
"device = \"cpu\" if device is None else device"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Y7hyrIrO8tcY"
|
||||
},
|
||||
"source": [
|
||||
"Various constants"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"metadata": {
|
||||
"id": "MqS3OM6Q8x6g"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"ref_model_name = \"lvwerra/gpt2-imdb\"\n",
|
||||
"model_name = \"lvwerra/gpt2-imdb-pos-v2\"\n",
|
||||
"reward_model = \"lvwerra/distilbert-imdb\"\n",
|
||||
"\n",
|
||||
"N_BEST_OF = 4"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "c1YcXeElg6or"
|
||||
},
|
||||
"source": [
|
||||
"Models and tokenizers"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "b855NrL181Hh"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name)\n",
|
||||
"\n",
|
||||
"ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)\n",
|
||||
"\n",
|
||||
"reward_pipe = pipeline(\"sentiment-analysis\", model=reward_model, device=device)\n",
|
||||
"\n",
|
||||
"tokenizer = AutoTokenizer.from_pretrained(ref_model_name)\n",
|
||||
"\n",
|
||||
"tokenizer.pad_token = tokenizer.eos_token\n",
|
||||
"\n",
|
||||
"# put models to accelerator\n",
|
||||
"model.to(device)\n",
|
||||
"ref_model.to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Z1Cz0gCFhZYJ"
|
||||
},
|
||||
"source": [
|
||||
"Dataset building"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"metadata": {
|
||||
"id": "LqLVEp5p_8XM"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Generating train split: 100%|██████████| 25000/25000 [00:00<00:00, 113700.67 examples/s]\n",
|
||||
"Generating test split: 100%|██████████| 25000/25000 [00:00<00:00, 131049.39 examples/s]\n",
|
||||
"Generating unsupervised split: 100%|██████████| 50000/50000 [00:00<00:00, 126486.39 examples/s]\n",
|
||||
"Filter: 100%|██████████| 25000/25000 [00:00<00:00, 238843.61 examples/s]\n",
|
||||
"Map: 0%| | 0/24895 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors\n",
|
||||
"Map: 100%|██████████| 24895/24895 [00:17<00:00, 1462.36 examples/s]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"def build_dataset(\n",
|
||||
" tokenizer,\n",
|
||||
" dataset_name=\"stanfordnlp/imdb\",\n",
|
||||
" input_min_text_length=2,\n",
|
||||
" input_max_text_length=8,\n",
|
||||
"):\n",
|
||||
" # load imdb with datasets\n",
|
||||
" ds = load_dataset(dataset_name, split=\"train\")\n",
|
||||
" ds = ds.rename_columns({\"text\": \"review\"})\n",
|
||||
" ds = ds.filter(lambda x: len(x[\"review\"]) > 200, batched=False)\n",
|
||||
"\n",
|
||||
" input_size = LengthSampler(input_min_text_length, input_max_text_length)\n",
|
||||
"\n",
|
||||
" def tokenize(sample):\n",
|
||||
" sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n",
|
||||
" sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n",
|
||||
" return sample\n",
|
||||
"\n",
|
||||
" ds = ds.map(tokenize, batched=False)\n",
|
||||
" ds.set_format(type=\"torch\")\n",
|
||||
" return ds\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"dataset = build_dataset(tokenizer)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"metadata": {
|
||||
"id": "AqA2McjMAxNw"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gen_kwargs = {\n",
|
||||
" \"min_length\": -1,\n",
|
||||
" \"top_k\": 0.0,\n",
|
||||
" \"top_p\": 1.0,\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"pad_token_id\": tokenizer.eos_token_id,\n",
|
||||
"}\n",
|
||||
"sent_kwargs = {\"top_k\": None, \"function_to_apply\": \"none\", \"batch_size\": 16}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"metadata": {
|
||||
"id": "L_q4qs35AxcR"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"output_min_length = 4\n",
|
||||
"output_max_length = 16\n",
|
||||
"output_length_sampler = LengthSampler(output_min_length, output_max_length)\n",
|
||||
"\n",
|
||||
"#### get a batch from the dataset\n",
|
||||
"bs = 16\n",
|
||||
"output_data = dict()\n",
|
||||
"dataset.set_format(\"pandas\")\n",
|
||||
"df_batch = dataset[:].sample(bs)\n",
|
||||
"output_data[\"query\"] = df_batch[\"query\"].tolist()\n",
|
||||
"query_tensors = df_batch[\"input_ids\"].tolist()\n",
|
||||
"\n",
|
||||
"# :: [Resp]\n",
|
||||
"response_tensors_ref, response_tensors = [], []\n",
|
||||
"# :: [[Resp]]\n",
|
||||
"response_tensors_best_of = []"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "QVfpyHnZBLKY"
|
||||
},
|
||||
"source": [
|
||||
"\n",
|
||||
"Generation using various models"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "-imZ7uEFBNbw"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i in range(bs):\n",
|
||||
" gen_len = output_length_sampler()\n",
|
||||
"\n",
|
||||
" query = torch.tensor(query_tensors[i])\n",
|
||||
"\n",
|
||||
" output = ref_model.generate(\n",
|
||||
" query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n",
|
||||
" ).squeeze()\n",
|
||||
" response_tensors_ref.append(tokenizer.decode(output))\n",
|
||||
"\n",
|
||||
" output = model.generate(\n",
|
||||
" query.unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n",
|
||||
" ).squeeze()\n",
|
||||
" response_tensors.append(tokenizer.decode(output))\n",
|
||||
"\n",
|
||||
" # generating copies of the same query for the Best-of-n sampling\n",
|
||||
" queries = query.repeat((N_BEST_OF, 1))\n",
|
||||
" output = ref_model.generate(\n",
|
||||
" queries.to(device), max_new_tokens=gen_len, **gen_kwargs\n",
|
||||
" ).squeeze()\n",
|
||||
" response_tensors_best_of.append(tokenizer.batch_decode(output))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Jp5FC0Y5h_Sf"
|
||||
},
|
||||
"source": [
|
||||
"Scoring"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {
|
||||
"id": "PyDbbAQ0F_h7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"scores_ref = [\n",
|
||||
" output[0][\"score\"] for output in reward_pipe(response_tensors_ref, **sent_kwargs)\n",
|
||||
"]\n",
|
||||
"scores = [output[0][\"score\"] for output in reward_pipe(response_tensors, **sent_kwargs)]\n",
|
||||
"scores_best_of = []\n",
|
||||
"for i, response in enumerate(response_tensors_best_of):\n",
|
||||
" # base_score = scores_ref[i]\n",
|
||||
" scores_best_of.append(\n",
|
||||
" torch.tensor(\n",
|
||||
" [output[0][\"score\"] for output in reward_pipe(response, **sent_kwargs)]\n",
|
||||
" )\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/",
|
||||
"height": 682
|
||||
},
|
||||
"id": "nA1GDNJEiGm-",
|
||||
"outputId": "1389c686-0751-4304-dea2-b71fd68748e1"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/html": [
|
||||
"<div>\n",
|
||||
"<style scoped>\n",
|
||||
" .dataframe tbody tr th:only-of-type {\n",
|
||||
" vertical-align: middle;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe tbody tr th {\n",
|
||||
" vertical-align: top;\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" .dataframe thead th {\n",
|
||||
" text-align: right;\n",
|
||||
" }\n",
|
||||
"</style>\n",
|
||||
"<table border=\"1\" class=\"dataframe\">\n",
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>query</th>\n",
|
||||
" <th>response (ref)</th>\n",
|
||||
" <th>scores (ref)</th>\n",
|
||||
" <th>response (RLHF)</th>\n",
|
||||
" <th>scores (RLHF)</th>\n",
|
||||
" <th>response (best_of)</th>\n",
|
||||
" <th>scores (best_of)</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>This movie is one of</td>\n",
|
||||
" <td>This movie is one of the most twisted films I</td>\n",
|
||||
" <td>2.094254</td>\n",
|
||||
" <td>This movie is one of the finest directors of the</td>\n",
|
||||
" <td>2.726879</td>\n",
|
||||
" <td>This movie is one of the best looking movies I</td>\n",
|
||||
" <td>2.705925</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>one may</td>\n",
|
||||
" <td>one may feel we are seeing more</td>\n",
|
||||
" <td>1.478813</td>\n",
|
||||
" <td>one may not have great assets,</td>\n",
|
||||
" <td>0.420451</td>\n",
|
||||
" <td>one may not be supported, terrible</td>\n",
|
||||
" <td>2.043730</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>This is an amazing film,</td>\n",
|
||||
" <td>This is an amazing film, one of our favorite g...</td>\n",
|
||||
" <td>2.871389</td>\n",
|
||||
" <td>This is an amazing film, with all thelike wond...</td>\n",
|
||||
" <td>2.918770</td>\n",
|
||||
" <td>This is an amazing film, very moving and this ...</td>\n",
|
||||
" <td>2.871694</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>just below</td>\n",
|
||||
" <td>just below)and makes it seem as</td>\n",
|
||||
" <td>0.861618</td>\n",
|
||||
" <td>just below the world capital is a man</td>\n",
|
||||
" <td>0.238322</td>\n",
|
||||
" <td>just below) in this beautiful comedy.</td>\n",
|
||||
" <td>2.760033</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>Return To the</td>\n",
|
||||
" <td>Return To the Museum. That film, called Bl</td>\n",
|
||||
" <td>0.017376</td>\n",
|
||||
" <td>Return To the East\" is a fascinating film,</td>\n",
|
||||
" <td>2.648028</td>\n",
|
||||
" <td>Return To the International: Miyazaki, by Ts</td>\n",
|
||||
" <td>1.072344</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5</th>\n",
|
||||
" <td>Brando plays the ace jet</td>\n",
|
||||
" <td>Brando plays the ace jet fighter pilot, who stops</td>\n",
|
||||
" <td>0.565335</td>\n",
|
||||
" <td>Brando plays the ace jet pilot, who's a</td>\n",
|
||||
" <td>0.668954</td>\n",
|
||||
" <td>Brando plays the ace jet pilot Charlie; his fo...</td>\n",
|
||||
" <td>0.679582</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>6</th>\n",
|
||||
" <td>And a rather U</td>\n",
|
||||
" <td>And a rather Utopian horror movie and with good</td>\n",
|
||||
" <td>2.245751</td>\n",
|
||||
" <td>And a rather Utop Congressional Movie, with a 45</td>\n",
|
||||
" <td>0.307100</td>\n",
|
||||
" <td>And a rather U of A complete combination of wh...</td>\n",
|
||||
" <td>2.209265</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>7</th>\n",
|
||||
" <td>The plot of this movie hangs</td>\n",
|
||||
" <td>The plot of this movie hangs in the balance as...</td>\n",
|
||||
" <td>1.122540</td>\n",
|
||||
" <td>The plot of this movie hangs out well. The who...</td>\n",
|
||||
" <td>2.195263</td>\n",
|
||||
" <td>The plot of this movie hangs together within t...</td>\n",
|
||||
" <td>1.310783</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>8</th>\n",
|
||||
" <td>This isn't</td>\n",
|
||||
" <td>This isn't all that bad; as for my</td>\n",
|
||||
" <td>0.623968</td>\n",
|
||||
" <td>This isn't a good film because I loved it</td>\n",
|
||||
" <td>1.694601</td>\n",
|
||||
" <td>This isn't bad writing, powerful actors and sp...</td>\n",
|
||||
" <td>1.835901</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>9</th>\n",
|
||||
" <td>This movie was for a</td>\n",
|
||||
" <td>This movie was for a good reason!' Uh, OK</td>\n",
|
||||
" <td>0.437566</td>\n",
|
||||
" <td>This movie was for a fun, and grand Robinson</td>\n",
|
||||
" <td>2.531890</td>\n",
|
||||
" <td>This movie was for a bastard.<br /><br</td>\n",
|
||||
" <td>2.311337</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>10</th>\n",
|
||||
" <td>witty. funny.</td>\n",
|
||||
" <td>witty. funny.<|endoftext|></td>\n",
|
||||
" <td>1.636344</td>\n",
|
||||
" <td>witty. funny. funnier. more funny. funnier. fu...</td>\n",
|
||||
" <td>2.132353</td>\n",
|
||||
" <td>witty. funny. In the first scene the comical n...</td>\n",
|
||||
" <td>2.164077</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>11</th>\n",
|
||||
" <td>It's very hard</td>\n",
|
||||
" <td>It's very hard to believe that anyone would en...</td>\n",
|
||||
" <td>1.003727</td>\n",
|
||||
" <td>It's very hard to wrap your mind around what h...</td>\n",
|
||||
" <td>0.778888</td>\n",
|
||||
" <td>It's very hard to wrap this up, due to lack of...</td>\n",
|
||||
" <td>1.598843</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>12</th>\n",
|
||||
" <td>Absolutely fantastic trash....this one</td>\n",
|
||||
" <td>Absolutely fantastic trash....this one was hav...</td>\n",
|
||||
" <td>1.350834</td>\n",
|
||||
" <td>Absolutely fantastic trash....this one is a pe...</td>\n",
|
||||
" <td>2.177587</td>\n",
|
||||
" <td>Absolutely fantastic trash....this one ruins i...</td>\n",
|
||||
" <td>2.221997</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>13</th>\n",
|
||||
" <td>Prior to</td>\n",
|
||||
" <td>Prior to this action film,</td>\n",
|
||||
" <td>0.242474</td>\n",
|
||||
" <td>Prior to Christian Kane's star</td>\n",
|
||||
" <td>0.297408</td>\n",
|
||||
" <td>Prior to his restoration, Passion</td>\n",
|
||||
" <td>1.655534</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>14</th>\n",
|
||||
" <td>i,</td>\n",
|
||||
" <td>i, Marty Rathbun, Damon Wayans, Mark Watney and</td>\n",
|
||||
" <td>0.105734</td>\n",
|
||||
" <td>i, perhaps the great movie the director should...</td>\n",
|
||||
" <td>1.336116</td>\n",
|
||||
" <td>i, Martin was a thrill of 70s---wow!lee and Heath</td>\n",
|
||||
" <td>2.277638</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>15</th>\n",
|
||||
" <td>The film</td>\n",
|
||||
" <td>The film takes a very grim craggy look</td>\n",
|
||||
" <td>0.069017</td>\n",
|
||||
" <td>The film is one of the best of that era</td>\n",
|
||||
" <td>2.737825</td>\n",
|
||||
" <td>The film's ambition was almost so great that its</td>\n",
|
||||
" <td>2.357480</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" query \\\n",
|
||||
"0 This movie is one of \n",
|
||||
"1 one may \n",
|
||||
"2 This is an amazing film, \n",
|
||||
"3 just below \n",
|
||||
"4 Return To the \n",
|
||||
"5 Brando plays the ace jet \n",
|
||||
"6 And a rather U \n",
|
||||
"7 The plot of this movie hangs \n",
|
||||
"8 This isn't \n",
|
||||
"9 This movie was for a \n",
|
||||
"10 witty. funny. \n",
|
||||
"11 It's very hard \n",
|
||||
"12 Absolutely fantastic trash....this one \n",
|
||||
"13 Prior to \n",
|
||||
"14 i, \n",
|
||||
"15 The film \n",
|
||||
"\n",
|
||||
" response (ref) scores (ref) \\\n",
|
||||
"0 This movie is one of the most twisted films I 2.094254 \n",
|
||||
"1 one may feel we are seeing more 1.478813 \n",
|
||||
"2 This is an amazing film, one of our favorite g... 2.871389 \n",
|
||||
"3 just below)and makes it seem as 0.861618 \n",
|
||||
"4 Return To the Museum. That film, called Bl 0.017376 \n",
|
||||
"5 Brando plays the ace jet fighter pilot, who stops 0.565335 \n",
|
||||
"6 And a rather Utopian horror movie and with good 2.245751 \n",
|
||||
"7 The plot of this movie hangs in the balance as... 1.122540 \n",
|
||||
"8 This isn't all that bad; as for my 0.623968 \n",
|
||||
"9 This movie was for a good reason!' Uh, OK 0.437566 \n",
|
||||
"10 witty. funny.<|endoftext|> 1.636344 \n",
|
||||
"11 It's very hard to believe that anyone would en... 1.003727 \n",
|
||||
"12 Absolutely fantastic trash....this one was hav... 1.350834 \n",
|
||||
"13 Prior to this action film, 0.242474 \n",
|
||||
"14 i, Marty Rathbun, Damon Wayans, Mark Watney and 0.105734 \n",
|
||||
"15 The film takes a very grim craggy look 0.069017 \n",
|
||||
"\n",
|
||||
" response (RLHF) scores (RLHF) \\\n",
|
||||
"0 This movie is one of the finest directors of the 2.726879 \n",
|
||||
"1 one may not have great assets, 0.420451 \n",
|
||||
"2 This is an amazing film, with all thelike wond... 2.918770 \n",
|
||||
"3 just below the world capital is a man 0.238322 \n",
|
||||
"4 Return To the East\" is a fascinating film, 2.648028 \n",
|
||||
"5 Brando plays the ace jet pilot, who's a 0.668954 \n",
|
||||
"6 And a rather Utop Congressional Movie, with a 45 0.307100 \n",
|
||||
"7 The plot of this movie hangs out well. The who... 2.195263 \n",
|
||||
"8 This isn't a good film because I loved it 1.694601 \n",
|
||||
"9 This movie was for a fun, and grand Robinson 2.531890 \n",
|
||||
"10 witty. funny. funnier. more funny. funnier. fu... 2.132353 \n",
|
||||
"11 It's very hard to wrap your mind around what h... 0.778888 \n",
|
||||
"12 Absolutely fantastic trash....this one is a pe... 2.177587 \n",
|
||||
"13 Prior to Christian Kane's star 0.297408 \n",
|
||||
"14 i, perhaps the great movie the director should... 1.336116 \n",
|
||||
"15 The film is one of the best of that era 2.737825 \n",
|
||||
"\n",
|
||||
" response (best_of) scores (best_of) \n",
|
||||
"0 This movie is one of the best looking movies I 2.705925 \n",
|
||||
"1 one may not be supported, terrible 2.043730 \n",
|
||||
"2 This is an amazing film, very moving and this ... 2.871694 \n",
|
||||
"3 just below) in this beautiful comedy. 2.760033 \n",
|
||||
"4 Return To the International: Miyazaki, by Ts 1.072344 \n",
|
||||
"5 Brando plays the ace jet pilot Charlie; his fo... 0.679582 \n",
|
||||
"6 And a rather U of A complete combination of wh... 2.209265 \n",
|
||||
"7 The plot of this movie hangs together within t... 1.310783 \n",
|
||||
"8 This isn't bad writing, powerful actors and sp... 1.835901 \n",
|
||||
"9 This movie was for a bastard.<br /><br 2.311337 \n",
|
||||
"10 witty. funny. In the first scene the comical n... 2.164077 \n",
|
||||
"11 It's very hard to wrap this up, due to lack of... 1.598843 \n",
|
||||
"12 Absolutely fantastic trash....this one ruins i... 2.221997 \n",
|
||||
"13 Prior to his restoration, Passion 1.655534 \n",
|
||||
"14 i, Martin was a thrill of 70s---wow!lee and Heath 2.277638 \n",
|
||||
"15 The film's ambition was almost so great that its 2.357480 "
|
||||
]
|
||||
},
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"output_data[\"response (ref)\"] = response_tensors_ref\n",
|
||||
"output_data[\"scores (ref)\"] = scores_ref\n",
|
||||
"output_data[\"response (RLHF)\"] = response_tensors\n",
|
||||
"output_data[\"scores (RLHF)\"] = scores\n",
|
||||
"output_data[\"response (best_of)\"] = [\n",
|
||||
" response_tensors_best_of[i][a.argmax().item()] for i, a in enumerate(scores_best_of)\n",
|
||||
"]\n",
|
||||
"output_data[\"scores (best_of)\"] = [a.max().item() for a in scores_best_of]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# store results in a dataframe\n",
|
||||
"df_results = pd.DataFrame(output_data)\n",
|
||||
"df_results"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"gpuClass": "standard",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 1
|
||||
}
|
@ -1,7 +0,0 @@
|
||||
# Research projects that use TRL
|
||||
|
||||
Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information!
|
||||
|
||||
- [De-detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity)
|
||||
- [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama)
|
||||
- [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2)
|
@ -1,15 +0,0 @@
|
||||
# LayerSkip Training Recipe
|
||||
|
||||
Implements the training recipe as described in the [LayerSkip paper](https://huggingface.co/papers/2404.16710).
|
||||
|
||||
## Run training
|
||||
```
|
||||
cd scripts
|
||||
python layer_skip_sft.py
|
||||
```
|
||||
|
||||
## Run benchmark
|
||||
```
|
||||
cd scripts
|
||||
python benchmark_layer_skip.py
|
||||
```
|
@ -1,77 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import config
|
||||
import torch
|
||||
from torch.utils import benchmark
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
|
||||
def generate_tokens(model, inputs):
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
def generate_tokens_with_assistance(model, inputs, assistant_early_exit):
|
||||
outputs = model.generate(
|
||||
**inputs,
|
||||
assistant_early_exit=assistant_early_exit,
|
||||
do_sample=False,
|
||||
max_new_tokens=64,
|
||||
)
|
||||
return outputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ckpt = config.hub_model_id
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt)
|
||||
|
||||
prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: "
|
||||
|
||||
results = []
|
||||
label = "Generation Times"
|
||||
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="generate_tokens(model, inputs)",
|
||||
setup="from __main__ import generate_tokens",
|
||||
globals={"model": model, "inputs": inputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
label=label,
|
||||
sub_label="no layer skip",
|
||||
description="generation",
|
||||
).blocked_autorange()
|
||||
)
|
||||
|
||||
for i in range(1, model.config.num_hidden_layers):
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="generate_tokens_with_assistance(model, inputs, assistant_early_exit)",
|
||||
setup="from __main__ import generate_assistant_tokens",
|
||||
globals={"model": model, "assistant_early_exit": i, "inputs": inputs},
|
||||
num_threads=torch.get_num_threads(),
|
||||
label=label,
|
||||
sub_label=f"layer skip {i}",
|
||||
description="generation",
|
||||
).blocked_autorange()
|
||||
)
|
||||
|
||||
benchmark.Compare(results).print()
|
@ -1,48 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from trl import SFTTrainer
|
||||
|
||||
|
||||
class LayerSkipSFTTrainer(SFTTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.early_exit_layer = 0 # initialize with 0
|
||||
self.always_last_layer = True
|
||||
self.early_exit_loss_scale = 1.0
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
self.early_exit_layer = (
|
||||
self.early_exit_layer % (model.config.num_hidden_layers - 1)
|
||||
) + 1 # rotates between [1, num_hidden_layers-1]
|
||||
bs, seqlen = inputs.input_ids.shape
|
||||
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs, output_hidden_states=True)
|
||||
|
||||
hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype)
|
||||
if self.early_exit_layer != model.config.num_hidden_layers:
|
||||
hidden_state = model.model.norm(hidden_state)
|
||||
logits = model.lm_head(hidden_state)
|
||||
loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size)
|
||||
|
||||
if self.always_last_layer:
|
||||
loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size)
|
||||
loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last
|
||||
# normalize loss scales
|
||||
loss = loss / (1.0 + self.early_exit_loss_scale)
|
||||
else:
|
||||
loss = loss_early
|
||||
|
||||
return loss
|
@ -1,90 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import config
|
||||
import torch
|
||||
from custom_trainer import LayerSkipSFTTrainer
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from trl import DataCollatorForCompletionOnlyLM, SFTConfig
|
||||
|
||||
|
||||
def formatting_prompts_func(example):
|
||||
text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}"
|
||||
|
||||
# Inject eos_token as a string before tokenization, because they are not always added
|
||||
# See: https://github.com/huggingface/transformers/issues/22794 and
|
||||
# https://github.com/huggingface/trl/issues/1623
|
||||
if tokenizer.eos_token: # usually something like "</s>" for GPT2 or "<|endoftext|>"
|
||||
text += f"{tokenizer.eos_token}"
|
||||
|
||||
return text
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# load the dataset
|
||||
print("[INFO] loading the dataset...")
|
||||
train_dataset = load_dataset(config.dataset_name, split="train")
|
||||
|
||||
print(f"output_root_dir: {config.output_root_dir}")
|
||||
print(f"hub_model_id: {config.hub_model_id}")
|
||||
|
||||
# load the model and tokenizer
|
||||
print("[INFO] loading the model and tokenizer...")
|
||||
model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True)
|
||||
|
||||
# adding pad and eos tokens if not provided in the tokenizer
|
||||
if tokenizer.pad_token is None:
|
||||
# Add '[PAD]' token if it doesn't exist
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token:
|
||||
# Add '[EOS]' token if it doesn't exist
|
||||
tokenizer.add_special_tokens({"eos_token": "[EOS]"})
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
response_template = " ### Response:"
|
||||
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
|
||||
|
||||
args = SFTConfig(
|
||||
do_train=True,
|
||||
bf16=True,
|
||||
max_seq_length=None,
|
||||
per_device_train_batch_size=config.per_device_train_batch_size,
|
||||
gradient_accumulation_steps=config.gradient_accumulation_steps,
|
||||
learning_rate=config.learning_rate,
|
||||
packing=False,
|
||||
num_train_epochs=1.0,
|
||||
report_to="none",
|
||||
push_to_hub=True,
|
||||
hub_model_id=config.hub_model_id,
|
||||
output_dir=config.output_dir,
|
||||
save_steps=1000,
|
||||
save_total_limit=2,
|
||||
)
|
||||
|
||||
trainer = LayerSkipSFTTrainer(
|
||||
model,
|
||||
train_dataset=train_dataset,
|
||||
args=args,
|
||||
formatting_func=formatting_prompts_func,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
@ -1,18 +0,0 @@
|
||||
# RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model.
|
||||
There were three main steps to the training process:
|
||||
1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se:
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path=<LLAMA_MODEL_PATH> --streaming --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se`
|
||||
2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm:
|
||||
- `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=<LLAMA_SE_MODEL>`
|
||||
3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model:
|
||||
- `accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/research_projects/stack_llama/scripts/rl_training.py --log_with=wandb --model_name=<LLAMA_SE_MODEL> --reward_model_name=<LLAMA_SE_RM_MODEL> --adafactor=False --tokenizer_name=<LLAMA_TOKENIZER> --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam`
|
||||
|
||||
|
||||
LoRA layers were using at all stages to reduce memory requirements.
|
||||
At each stage the peft adapter layers were merged with the base model, using:
|
||||
```shell
|
||||
python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ
|
||||
```
|
||||
Note that this script requires `peft>=0.3.0`.
|
||||
|
||||
For access to the base llama-7b model, please see Meta's [release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) and [request form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform).
|
@ -1,60 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the
|
||||
merged model.
|
||||
"""
|
||||
|
||||
adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"})
|
||||
base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"})
|
||||
output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge"
|
||||
assert script_args.base_model_name is not None, "please provide the name of the Base model"
|
||||
assert script_args.output_name is not None, "please provide the output name of the merged model"
|
||||
|
||||
peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name)
|
||||
if peft_config.task_type == "SEQ_CLS":
|
||||
# The sequence classification task is used for the reward model in PPO
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
script_args.base_model_name, num_labels=1, dtype=torch.bfloat16
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(script_args.base_model_name, return_dict=True, dtype=torch.bfloat16)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name)
|
||||
|
||||
# Load the PEFT model
|
||||
model = PeftModel.from_pretrained(model, script_args.adapter_model_name)
|
||||
model.eval()
|
||||
|
||||
model = model.merge_and_unload()
|
||||
|
||||
model.save_pretrained(f"{script_args.output_name}")
|
||||
tokenizer.save_pretrained(f"{script_args.output_name}")
|
||||
model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False)
|
@ -1,321 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig, TaskType, get_peft_model
|
||||
from transformers import (
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
PreTrainedTokenizerBase,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
These arguments vary depending on how many GPUs you have, what their capacity and features are, and what size model you want to train.
|
||||
"""
|
||||
|
||||
local_rank: Optional[int] = field(default=-1, metadata={"help": "Used for multi-gpu"})
|
||||
resume_from_checkpoint: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "If you want to resume training where it left off."},
|
||||
)
|
||||
deepspeed: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Path to deepspeed config if using deepspeed. You may need this if the model that you want to train doesn't fit on a single GPU."
|
||||
},
|
||||
)
|
||||
per_device_train_batch_size: Optional[int] = field(default=4)
|
||||
per_device_eval_batch_size: Optional[int] = field(default=1)
|
||||
gradient_accumulation_steps: Optional[int] = field(default=1)
|
||||
learning_rate: Optional[float] = field(default=2e-5)
|
||||
weight_decay: Optional[float] = field(default=0.001)
|
||||
model_name: Optional[str] = field(
|
||||
default="gpt2",
|
||||
metadata={
|
||||
"help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
|
||||
},
|
||||
)
|
||||
tokenizer_name: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The tokenizer for your model, if left empty will use the default for your model",
|
||||
},
|
||||
)
|
||||
bf16: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "This essentially cuts the training time in half if you want to sacrifice a little precision and have a supported GPU."
|
||||
},
|
||||
)
|
||||
num_train_epochs: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "The number of training epochs for the reward model."},
|
||||
)
|
||||
train_subset: Optional[int] = field(
|
||||
default=100000,
|
||||
metadata={"help": "The size of the subset of the training data to use"},
|
||||
)
|
||||
eval_subset: Optional[int] = field(
|
||||
default=50000,
|
||||
metadata={"help": "The size of the subset of the eval data to use"},
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Enables gradient checkpointing."},
|
||||
)
|
||||
optim: Optional[str] = field(
|
||||
default="adamw_hf",
|
||||
metadata={"help": "The optimizer to use."},
|
||||
)
|
||||
lr_scheduler_type: Optional[str] = field(
|
||||
default="linear",
|
||||
metadata={"help": "The lr scheduler"},
|
||||
)
|
||||
max_length: Optional[int] = field(default=512)
|
||||
eval_first_step: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to run eval after the first step"},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
set_seed(script_args.seed)
|
||||
# Load the human stack-exchange-paired dataset for tuning the reward model.
|
||||
train_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/reward", split="train", verification_mode="no_checks"
|
||||
)
|
||||
if script_args.train_subset > 0:
|
||||
train_dataset = train_dataset.select(range(script_args.train_subset))
|
||||
eval_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/evaluation", split="train", verification_mode="no_checks"
|
||||
)
|
||||
if script_args.eval_subset > 0:
|
||||
eval_dataset = eval_dataset.select(range(script_args.eval_subset))
|
||||
# Define the training args. Needs to be done before the model is loaded if you are using deepspeed.
|
||||
model_name_split = script_args.model_name.split("/")[-1]
|
||||
output_name = (
|
||||
f"{model_name_split}_peft_stack-exchange-paired_rmts__{script_args.train_subset}_{script_args.learning_rate}"
|
||||
)
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=output_name,
|
||||
learning_rate=script_args.learning_rate,
|
||||
per_device_train_batch_size=script_args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
|
||||
num_train_epochs=script_args.num_train_epochs,
|
||||
weight_decay=script_args.weight_decay,
|
||||
eval_strategy="steps",
|
||||
eval_steps=500,
|
||||
save_strategy="steps",
|
||||
save_steps=500,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
gradient_checkpointing=script_args.gradient_checkpointing,
|
||||
deepspeed=script_args.deepspeed,
|
||||
local_rank=script_args.local_rank,
|
||||
remove_unused_columns=False,
|
||||
label_names=[],
|
||||
bf16=script_args.bf16,
|
||||
logging_strategy="steps",
|
||||
optim=script_args.optim,
|
||||
lr_scheduler_type=script_args.lr_scheduler_type,
|
||||
seed=script_args.seed,
|
||||
)
|
||||
|
||||
|
||||
# Load the value-head model and tokenizer.
|
||||
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.SEQ_CLS,
|
||||
inference_mode=False,
|
||||
r=8,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(script_args.model_name, num_labels=1, dtype=torch.bfloat16)
|
||||
model = get_peft_model(model, peft_config)
|
||||
model.print_trainable_parameters()
|
||||
|
||||
# Need to do this for gpt2, because it doesn't have an official pad token.
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model.config.pad_token_id = tokenizer.eos_token_id
|
||||
model.config.use_cache = not script_args.gradient_checkpointing
|
||||
num_proc = 24 # Can adjust to be higher if you have more processors.
|
||||
original_columns = train_dataset.column_names
|
||||
|
||||
|
||||
# Turn the dataset into pairs of post + summaries, where text_j is the preferred question + answer and text_k is the other.
|
||||
# Then tokenize the dataset.
|
||||
def preprocess_function(examples):
|
||||
new_examples = {
|
||||
"input_ids_j": [],
|
||||
"attention_mask_j": [],
|
||||
"input_ids_k": [],
|
||||
"attention_mask_k": [],
|
||||
}
|
||||
for question, response_j, response_k in zip(examples["question"], examples["response_j"], examples["response_k"]):
|
||||
tokenized_j = tokenizer("Question: " + question + "\n\nAnswer: " + response_j, truncation=True)
|
||||
tokenized_k = tokenizer("Question: " + question + "\n\nAnswer: " + response_k, truncation=True)
|
||||
|
||||
new_examples["input_ids_j"].append(tokenized_j["input_ids"])
|
||||
new_examples["attention_mask_j"].append(tokenized_j["attention_mask"])
|
||||
new_examples["input_ids_k"].append(tokenized_k["input_ids"])
|
||||
new_examples["attention_mask_k"].append(tokenized_k["attention_mask"])
|
||||
|
||||
return new_examples
|
||||
|
||||
|
||||
# preprocess the dataset and filter out QAs that are longer than script_args.max_length
|
||||
train_dataset = train_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
train_dataset = train_dataset.filter(
|
||||
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
|
||||
num_proc=num_proc,
|
||||
)
|
||||
|
||||
eval_dataset = eval_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
eval_dataset = eval_dataset.filter(
|
||||
lambda x: len(x["input_ids_j"]) <= script_args.max_length and len(x["input_ids_k"]) <= script_args.max_length,
|
||||
num_proc=num_proc,
|
||||
)
|
||||
|
||||
|
||||
# We need to define a special data collator that batches the data in our j vs k format.
|
||||
@dataclass
|
||||
class RewardDataCollatorWithPadding:
|
||||
tokenizer: PreTrainedTokenizerBase
|
||||
padding: Union[bool, str, PaddingStrategy] = True
|
||||
pad_to_multiple_of: Optional[int] = None
|
||||
return_tensors: str = "pt"
|
||||
|
||||
def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
|
||||
features_j = []
|
||||
features_k = []
|
||||
for feature in features:
|
||||
features_j.append(
|
||||
{
|
||||
"input_ids": feature["input_ids_j"],
|
||||
"attention_mask": feature["attention_mask_j"],
|
||||
}
|
||||
)
|
||||
features_k.append(
|
||||
{
|
||||
"input_ids": feature["input_ids_k"],
|
||||
"attention_mask": feature["attention_mask_k"],
|
||||
}
|
||||
)
|
||||
batch_j = self.tokenizer.pad(
|
||||
features_j,
|
||||
padding=self.padding,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
)
|
||||
batch_k = self.tokenizer.pad(
|
||||
features_k,
|
||||
padding=self.padding,
|
||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||
return_tensors=self.return_tensors,
|
||||
)
|
||||
batch = {
|
||||
"input_ids_j": batch_j["input_ids"],
|
||||
"attention_mask_j": batch_j["attention_mask"],
|
||||
"input_ids_k": batch_k["input_ids"],
|
||||
"attention_mask_k": batch_k["attention_mask"],
|
||||
"return_loss": True,
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
# Define the metric that we'll use for validation.
|
||||
accuracy = evaluate.load("accuracy")
|
||||
|
||||
|
||||
def compute_metrics(eval_pred):
|
||||
predictions, _ = eval_pred
|
||||
# Here, predictions is rewards_j and rewards_k.
|
||||
# We want to see how much of the time rewards_j > rewards_k.
|
||||
predictions = np.argmax(predictions, axis=0)
|
||||
labels = np.zeros(predictions.shape)
|
||||
return accuracy.compute(predictions=predictions, references=labels)
|
||||
|
||||
|
||||
class RewardTrainer(Trainer):
|
||||
# Define how to compute the reward loss. We use the InstructGPT pairwise logloss: https://huggingface.co/papers/2203.02155
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
|
||||
rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
|
||||
loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
|
||||
if return_outputs:
|
||||
return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
|
||||
return loss
|
||||
|
||||
|
||||
# Train the model, woohoo.
|
||||
trainer = RewardTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
compute_metrics=compute_metrics,
|
||||
data_collator=RewardDataCollatorWithPadding(tokenizer=tokenizer),
|
||||
)
|
||||
|
||||
|
||||
if script_args.eval_first_step:
|
||||
|
||||
class EvaluateFirstStepCallback(TrainerCallback):
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
if state.global_step == 1:
|
||||
control.should_evaluate = True
|
||||
|
||||
trainer.add_callback(EvaluateFirstStepCallback())
|
||||
|
||||
trainer.train(script_args.resume_from_checkpoint)
|
||||
|
||||
print("Saving last checkpoint of the model")
|
||||
model.save_pretrained(output_name + "_peft_last_checkpoint")
|
@ -1,270 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import Adafactor, AutoTokenizer, HfArgumentParser, pipeline, set_seed
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
from trl.core import LengthSampler
|
||||
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Casual LM model we wish to fine-tune with PPO
|
||||
"""
|
||||
|
||||
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
|
||||
# models like gpt-neo* models are more suitable.
|
||||
model_name: Optional[str] = field(default="", metadata={"help": "the model name"})
|
||||
tokenizer_name: Optional[str] = field(default="", metadata={"help": "the tokenizer name"})
|
||||
reward_model_name: Optional[str] = field(default="", metadata={"help": "the reward model name"})
|
||||
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
|
||||
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
|
||||
output_max_length: Optional[int] = field(default=128, metadata={"help": "maximum length for generation"})
|
||||
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
||||
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
|
||||
ppo_epochs: Optional[int] = field(default=4, metadata={"help": "the number of ppo epochs"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=4, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
adafactor: Optional[bool] = field(default=False, metadata={"help": "whether to use the adafactor optimizer"})
|
||||
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
|
||||
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
|
||||
reward_baseline: Optional[float] = field(
|
||||
default=0.0,
|
||||
metadata={"help": "a baseline value that is subtracted from the reward"},
|
||||
)
|
||||
batched_gen: Optional[bool] = field(default=False, metadata={"help": "whether to use the batched text gen"})
|
||||
save_freq: Optional[int] = field(default=None, metadata={"help": "n steps to save the model"})
|
||||
output_dir: Optional[str] = field(default="runs/", metadata={"help": "n steps to save the model"})
|
||||
seed: Optional[int] = field(default=0, metadata={"help": "the seed"})
|
||||
steps: Optional[int] = field(default=20000, metadata={"help": "number of epochs"})
|
||||
init_kl_coef: Optional[float] = field(
|
||||
default=0.2,
|
||||
metadata={"help": "Initial KL penalty coefficient (used for adaptive and linear control)"},
|
||||
)
|
||||
|
||||
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
|
||||
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args: ScriptArguments = parser.parse_args_into_dataclasses()[0]
|
||||
reward_model_name = script_args.reward_model_name
|
||||
dataset_name = "lvwerra/stack-exchange-paired"
|
||||
config = PPOConfig(
|
||||
steps=script_args.steps,
|
||||
model_name=script_args.model_name,
|
||||
learning_rate=script_args.learning_rate,
|
||||
log_with=script_args.log_with,
|
||||
batch_size=script_args.batch_size,
|
||||
mini_batch_size=script_args.mini_batch_size,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
optimize_device_cache=True,
|
||||
early_stopping=script_args.early_stopping,
|
||||
target_kl=script_args.target_kl,
|
||||
ppo_epochs=script_args.ppo_epochs,
|
||||
seed=script_args.seed,
|
||||
init_kl_coef=script_args.init_kl_coef,
|
||||
adap_kl_ctrl=script_args.adap_kl_ctrl,
|
||||
)
|
||||
|
||||
train_dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired", data_dir="data/rl", split="train", verification_mode="no_checks"
|
||||
)
|
||||
train_dataset = train_dataset.select(range(100000))
|
||||
original_columns = train_dataset.column_names
|
||||
|
||||
# We then define the arguments to pass to the sentiment analysis pipeline.
|
||||
# We set `return_all_scores` to True to get the sentiment score for each token.
|
||||
sent_kwargs = {
|
||||
"return_all_scores": True,
|
||||
"function_to_apply": "none",
|
||||
"batch_size": 16,
|
||||
"truncation": True,
|
||||
}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.tokenizer_name)
|
||||
# GPT-2 tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token.
|
||||
# only for this model.
|
||||
|
||||
if getattr(tokenizer, "pad_token", None) is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
# Below is an example function to build the dataset. In our case, we use the IMDB dataset
|
||||
# from the `datasets` library. One should customize this function to train the model on
|
||||
# its own dataset.
|
||||
def build_dataset(
|
||||
tokenizer,
|
||||
dataset_name="lvwerra/stack-exchange-paired",
|
||||
):
|
||||
"""
|
||||
Build dataset for training. This builds the dataset from `load_dataset`, one should
|
||||
customize this function to train the model on its own dataset.
|
||||
|
||||
Args:
|
||||
tokenizer (`transformers.PreTrainedTokenizer`):
|
||||
The tokenizer used for the model.
|
||||
dataset_name (`str`):
|
||||
The name of the dataset to be loaded.
|
||||
|
||||
Returns:
|
||||
dataloader (`torch.utils.data.DataLoader`):
|
||||
The dataloader for the dataset.
|
||||
"""
|
||||
|
||||
num_proc = 24
|
||||
|
||||
def preprocess_function(examples):
|
||||
new_examples = {
|
||||
"query": [],
|
||||
"input_ids": [],
|
||||
}
|
||||
for question in examples["question"]:
|
||||
query = "Question: " + question + "\n\nAnswer: "
|
||||
tokenized_question = tokenizer(query, truncation=True)
|
||||
new_examples["query"].append(query)
|
||||
new_examples["input_ids"].append(tokenized_question["input_ids"])
|
||||
|
||||
return new_examples
|
||||
|
||||
ds = train_dataset.map(
|
||||
preprocess_function,
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
ds = ds.filter(lambda x: len(x["input_ids"]) < 512, batched=False, num_proc=num_proc)
|
||||
|
||||
ds.set_format(type="torch")
|
||||
return ds
|
||||
|
||||
|
||||
# We retrieve the dataloader by calling the `build_dataset` function.
|
||||
dataset = build_dataset(tokenizer)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
set_seed(config.seed)
|
||||
|
||||
# Now let's build the model, the reference model, and the tokenizer.
|
||||
current_device = Accelerator().local_process_index
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
load_in_8bit=script_args.load_in_8bit,
|
||||
device_map={"": current_device},
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
optimizer = None
|
||||
if script_args.adafactor:
|
||||
optimizer = Adafactor(
|
||||
filter(lambda p: p.requires_grad, model.parameters()),
|
||||
scale_parameter=False,
|
||||
relative_step=False,
|
||||
warmup_init=False,
|
||||
lr=config.learning_rate,
|
||||
)
|
||||
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
|
||||
ppo_trainer = PPOTrainer(
|
||||
config,
|
||||
model,
|
||||
ref_model=None,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dataset,
|
||||
data_collator=collator,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
|
||||
# We then build the sentiment analysis pipeline using our reward model, passing the
|
||||
# model name and the sentiment analysis pipeline arguments. Let's also make sure to
|
||||
# set the device to the same device as the PPOTrainer.
|
||||
device = ppo_trainer.accelerator.device
|
||||
if ppo_trainer.accelerator.num_processes == 1:
|
||||
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a ` pipeline` bug
|
||||
sentiment_pipe = pipeline(
|
||||
"sentiment-analysis",
|
||||
model=reward_model_name,
|
||||
device_map={"": current_device},
|
||||
model_kwargs={"load_in_8bit": script_args.load_in_8bit},
|
||||
tokenizer=tokenizer,
|
||||
return_token_type_ids=False,
|
||||
)
|
||||
|
||||
if sentiment_pipe.model.config.pad_token_id is None:
|
||||
sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id
|
||||
# We then define the arguments to pass to the `generate` function. These arguments
|
||||
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
|
||||
# the `generate` function of the trained model.
|
||||
generation_kwargs = {
|
||||
# "min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.pad_token_id,
|
||||
"eos_token_id": 100_000,
|
||||
}
|
||||
output_min_length = 32
|
||||
output_max_length = script_args.output_max_length
|
||||
output_length_sampler = LengthSampler(output_min_length, output_max_length)
|
||||
|
||||
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
if epoch >= config.total_ppo_epochs:
|
||||
break
|
||||
|
||||
question_tensors = batch["input_ids"]
|
||||
|
||||
response_tensors = ppo_trainer.generate(
|
||||
question_tensors,
|
||||
return_prompt=False,
|
||||
length_sampler=output_length_sampler,
|
||||
**generation_kwargs,
|
||||
)
|
||||
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
|
||||
|
||||
# Compute reward score (using the sentiment analysis pipeline)
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
|
||||
rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]
|
||||
|
||||
# Run PPO step
|
||||
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
|
||||
if script_args.save_freq and epoch and epoch % script_args.save_freq == 0:
|
||||
ppo_trainer.save_pretrained(script_args.output_dir + f"step_{epoch}")
|
@ -1,222 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, logging, set_seed
|
||||
|
||||
from trl import SFTTrainer
|
||||
from trl.trainer import ConstantLengthDataset
|
||||
|
||||
|
||||
"""
|
||||
Fine-Tune Llama-7b on SE paired dataset
|
||||
"""
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_path", type=str, default="")
|
||||
parser.add_argument("--dataset_name", type=str, default="lvwerra/stack-exchange-paired")
|
||||
parser.add_argument("--subset", type=str, default="data/finetune")
|
||||
parser.add_argument("--split", type=str, default="train")
|
||||
parser.add_argument("--size_valid_set", type=int, default=4000)
|
||||
parser.add_argument("--streaming", action="store_true")
|
||||
parser.add_argument("--shuffle_buffer", type=int, default=5000)
|
||||
|
||||
parser.add_argument("--seq_length", type=int, default=1024)
|
||||
parser.add_argument("--max_steps", type=int, default=10000)
|
||||
parser.add_argument("--batch_size", type=int, default=4)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--eos_token_id", type=int, default=49152)
|
||||
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
||||
parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
|
||||
parser.add_argument("--num_warmup_steps", type=int, default=100)
|
||||
parser.add_argument("--weight_decay", type=float, default=0.05)
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=0)
|
||||
parser.add_argument("--fp16", action="store_true", default=False)
|
||||
parser.add_argument("--bf16", action="store_true", default=False)
|
||||
parser.add_argument("--gradient_checkpointing", action="store_true", default=False)
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--num_workers", type=int, default=None)
|
||||
parser.add_argument("--output_dir", type=str, default="./checkpoints")
|
||||
parser.add_argument("--log_freq", default=1, type=int)
|
||||
parser.add_argument("--eval_freq", default=1000, type=int)
|
||||
parser.add_argument("--save_freq", default=1000, type=int)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
|
||||
"""
|
||||
Estimate the average number of characters per token in the dataset.
|
||||
"""
|
||||
total_characters, total_tokens = 0, 0
|
||||
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
|
||||
text = prepare_sample_text(example)
|
||||
total_characters += len(text)
|
||||
if tokenizer.is_fast:
|
||||
total_tokens += len(tokenizer(text).tokens())
|
||||
else:
|
||||
total_tokens += len(tokenizer.tokenize(text))
|
||||
|
||||
return total_characters / total_tokens
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
"""
|
||||
Prints the number of trainable parameters in the model.
|
||||
"""
|
||||
trainable_params = 0
|
||||
all_param = 0
|
||||
for _, param in model.named_parameters():
|
||||
all_param += param.numel()
|
||||
if param.requires_grad:
|
||||
trainable_params += param.numel()
|
||||
print(
|
||||
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
||||
)
|
||||
|
||||
|
||||
def prepare_sample_text(example):
|
||||
"""Prepare the text from a sample of the dataset."""
|
||||
text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
|
||||
return text
|
||||
|
||||
|
||||
def create_datasets(tokenizer, args):
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
data_dir=args.subset,
|
||||
split=args.split,
|
||||
use_auth_token=True,
|
||||
num_proc=args.num_workers if not args.streaming else None,
|
||||
streaming=args.streaming,
|
||||
)
|
||||
if args.streaming:
|
||||
print("Loading the dataset in streaming mode")
|
||||
valid_data = dataset.take(args.size_valid_set)
|
||||
train_data = dataset.skip(args.size_valid_set)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=args.seed)
|
||||
else:
|
||||
dataset = dataset.train_test_split(test_size=0.005, seed=args.seed)
|
||||
train_data = dataset["train"]
|
||||
valid_data = dataset["test"]
|
||||
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
|
||||
|
||||
chars_per_token = chars_token_ratio(train_data, tokenizer)
|
||||
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
|
||||
|
||||
train_dataset = ConstantLengthDataset(
|
||||
tokenizer,
|
||||
train_data,
|
||||
formatting_func=prepare_sample_text,
|
||||
infinite=True,
|
||||
seq_length=args.seq_length,
|
||||
chars_per_token=chars_per_token,
|
||||
)
|
||||
valid_dataset = ConstantLengthDataset(
|
||||
tokenizer,
|
||||
valid_data,
|
||||
formatting_func=prepare_sample_text,
|
||||
infinite=False,
|
||||
seq_length=args.seq_length,
|
||||
chars_per_token=chars_per_token,
|
||||
)
|
||||
return train_dataset, valid_dataset
|
||||
|
||||
|
||||
def run_training(args, train_data, val_data):
|
||||
print("Loading the model")
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
train_data.start_iteration = 0
|
||||
|
||||
print("Starting main loop")
|
||||
|
||||
training_args = TrainingArguments(
|
||||
output_dir=args.output_dir,
|
||||
dataloader_drop_last=True,
|
||||
eval_strategy="steps",
|
||||
max_steps=args.max_steps,
|
||||
eval_steps=args.eval_freq,
|
||||
save_steps=args.save_freq,
|
||||
logging_steps=args.log_freq,
|
||||
per_device_train_batch_size=args.batch_size,
|
||||
per_device_eval_batch_size=args.batch_size,
|
||||
learning_rate=args.learning_rate,
|
||||
lr_scheduler_type=args.lr_scheduler_type,
|
||||
warmup_steps=args.num_warmup_steps,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
gradient_checkpointing=args.gradient_checkpointing,
|
||||
fp16=args.fp16,
|
||||
bf16=args.bf16,
|
||||
weight_decay=args.weight_decay,
|
||||
run_name="llama-7b-finetuned",
|
||||
report_to="wandb",
|
||||
ddp_find_unused_parameters=False,
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path, load_in_8bit=True, device_map={"": Accelerator().process_index}
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_data,
|
||||
eval_dataset=val_data,
|
||||
peft_config=lora_config,
|
||||
packing=True,
|
||||
)
|
||||
|
||||
print_trainable_parameters(trainer.model)
|
||||
|
||||
print("Training...")
|
||||
trainer.train()
|
||||
|
||||
print("Saving last checkpoint of the model")
|
||||
trainer.model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
|
||||
|
||||
|
||||
def main(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
|
||||
train_dataset, eval_dataset = create_datasets(tokenizer, args)
|
||||
run_training(args, train_dataset, eval_dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
assert args.model_path != "", "Please provide the llama model path"
|
||||
|
||||
set_seed(args.seed)
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
logging.set_verbosity_error()
|
||||
|
||||
main(args)
|
@ -1,75 +0,0 @@
|
||||
# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Install all the dependencies in the `requirements.txt`:
|
||||
|
||||
```
|
||||
$ pip install -U -r requirements.txt
|
||||
```
|
||||
|
||||
Since we will use `accelerate` for training, make sure to run:
|
||||
```
|
||||
$ accelerate config
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
There were two main steps to the DPO training process:
|
||||
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
|
||||
|
||||
```
|
||||
accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py \
|
||||
--output_dir="./sft" \
|
||||
--max_steps=500 \
|
||||
--save_steps=10 \
|
||||
--per_device_train_batch_size=4 \
|
||||
--per_device_eval_batch_size=1 \
|
||||
--gradient_accumulation_steps=2 \
|
||||
--gradient_checkpointing=False \
|
||||
--group_by_length=False \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_scheduler_type="cosine" \
|
||||
--warmup_steps=100 \
|
||||
--weight_decay=0.05 \
|
||||
--optim="paged_adamw_32bit" \
|
||||
--bf16=True \
|
||||
--remove_unused_columns=False \
|
||||
--run_name="sft_llama2" \
|
||||
--report_to="wandb"
|
||||
```
|
||||
1. Run the DPO trainer using the model saved by the previous step:
|
||||
```
|
||||
accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py \
|
||||
--model_name_or_path="sft/final_checkpoint" \
|
||||
--output_dir="dpo"
|
||||
```
|
||||
|
||||
|
||||
## Merging the adaptors
|
||||
|
||||
To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL:
|
||||
|
||||
```
|
||||
python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2"
|
||||
```
|
||||
|
||||
which will also push the model to your HuggingFace hub account.
|
||||
|
||||
## Running the model
|
||||
|
||||
We can load the DPO-trained LoRA adaptors which were saved by the DPO training step and load them via:
|
||||
|
||||
```py
|
||||
from peft import AutoPeftModelForCausalLM
|
||||
|
||||
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
"dpo/final_checkpoint",
|
||||
low_cpu_mem_usage=True,
|
||||
dtype=torch.float16,
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
model.generate(...)
|
||||
```
|
@ -1,252 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# 0. imports
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, set_seed
|
||||
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The arguments for the DPO training script.
|
||||
"""
|
||||
|
||||
# data parameters
|
||||
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
|
||||
|
||||
# training parameters
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default="../sft/results/final_checkpoint",
|
||||
metadata={"help": "the location of the SFT model name or path"},
|
||||
)
|
||||
learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
|
||||
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
|
||||
warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
|
||||
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
|
||||
optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
|
||||
|
||||
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
|
||||
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=4, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=True, metadata={"help": "whether to use gradient checkpointing"}
|
||||
)
|
||||
|
||||
gradient_checkpointing_use_reentrant: Optional[bool] = field(
|
||||
default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
|
||||
)
|
||||
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
|
||||
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
|
||||
|
||||
max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
|
||||
max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
|
||||
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
|
||||
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
|
||||
save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
|
||||
eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})
|
||||
|
||||
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
|
||||
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
|
||||
load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
|
||||
model_dtype: Optional[str] = field(
|
||||
default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
|
||||
)
|
||||
|
||||
# instrumentation
|
||||
report_to: Optional[str] = field(
|
||||
default="wandb",
|
||||
metadata={
|
||||
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
|
||||
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
|
||||
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
|
||||
},
|
||||
)
|
||||
# debug argument for distributed training
|
||||
ignore_bias_buffers: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
def get_stack_exchange_paired(
|
||||
data_dir: str = "data/rl",
|
||||
cache_dir: Optional[str] = None,
|
||||
num_proc=24,
|
||||
) -> Dataset:
|
||||
"""Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
|
||||
|
||||
The dataset is converted to a dictionary with the following structure:
|
||||
{
|
||||
'prompt': list[str],
|
||||
'chosen': list[str],
|
||||
'rejected': list[str],
|
||||
}
|
||||
|
||||
Prompts are structured as follows:
|
||||
"Question: " + <prompt> + "\n\nAnswer: "
|
||||
"""
|
||||
dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired",
|
||||
split="train",
|
||||
cache_dir=cache_dir,
|
||||
data_dir=data_dir,
|
||||
verification_mode="no_checks",
|
||||
)
|
||||
original_columns = dataset.column_names
|
||||
|
||||
def return_prompt_and_responses(samples) -> dict[str, str]:
|
||||
return {
|
||||
"prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
|
||||
"chosen": samples["response_j"],
|
||||
"rejected": samples["response_k"],
|
||||
}
|
||||
|
||||
return dataset.map(
|
||||
return_prompt_and_responses,
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
set_seed(script_args.seed)
|
||||
|
||||
# 1. load a pretrained model
|
||||
dtype = torch.float
|
||||
if script_args.model_dtype == "float16":
|
||||
dtype = torch.float16
|
||||
elif script_args.model_dtype == "bfloat16":
|
||||
dtype = torch.bfloat16
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name_or_path,
|
||||
low_cpu_mem_usage=True,
|
||||
dtype=dtype,
|
||||
load_in_4bit=script_args.load_in_4bit,
|
||||
device_map={"": Accelerator().local_process_index},
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
if script_args.ignore_bias_buffers:
|
||||
# torch distributed hack
|
||||
model._ddp_params_and_buffers_to_ignore = [
|
||||
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. Load the Stack-exchange paired dataset
|
||||
train_dataset = get_stack_exchange_paired(data_dir="data/rl")
|
||||
train_dataset = train_dataset.filter(
|
||||
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
|
||||
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
|
||||
num_proc=script_args.num_proc,
|
||||
)
|
||||
|
||||
# 3. Load evaluation dataset
|
||||
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation")
|
||||
eval_dataset = eval_dataset.filter(
|
||||
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
|
||||
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length,
|
||||
num_proc=script_args.num_proc,
|
||||
)
|
||||
|
||||
# 4. initialize training arguments:
|
||||
training_args = DPOConfig(
|
||||
per_device_train_batch_size=script_args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
|
||||
max_steps=script_args.max_steps,
|
||||
logging_steps=script_args.logging_steps,
|
||||
save_steps=script_args.save_steps,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
gradient_checkpointing=script_args.gradient_checkpointing,
|
||||
learning_rate=script_args.learning_rate,
|
||||
eval_strategy="steps",
|
||||
eval_steps=script_args.eval_steps,
|
||||
output_dir=script_args.output_dir,
|
||||
report_to=script_args.report_to,
|
||||
lr_scheduler_type=script_args.lr_scheduler_type,
|
||||
warmup_steps=script_args.warmup_steps,
|
||||
optim=script_args.optimizer_type,
|
||||
bf16=True,
|
||||
remove_unused_columns=False,
|
||||
run_name="dpo_llama2",
|
||||
gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
|
||||
seed=script_args.seed,
|
||||
)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=script_args.lora_r,
|
||||
lora_alpha=script_args.lora_alpha,
|
||||
lora_dropout=script_args.lora_dropout,
|
||||
target_modules=[
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
"k_proj",
|
||||
"out_proj",
|
||||
"fc_in",
|
||||
"fc_out",
|
||||
"wte",
|
||||
],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# 5. initialize the DPO trainer
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
beta=script_args.beta,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
processing_class=tokenizer,
|
||||
peft_config=peft_config,
|
||||
max_prompt_length=script_args.max_prompt_length,
|
||||
max_length=script_args.max_length,
|
||||
)
|
||||
|
||||
# 6. train
|
||||
dpo_trainer.train()
|
||||
dpo_trainer.save_model(script_args.output_dir)
|
||||
|
||||
# 7. save
|
||||
output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
|
||||
dpo_trainer.model.save_pretrained(output_dir)
|
@ -1,7 +0,0 @@
|
||||
transformers
|
||||
trl
|
||||
peft
|
||||
accelerate
|
||||
datasets
|
||||
bitsandbytes
|
||||
wandb
|
@ -1,212 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Fine-Tune Llama2-7b on SE paired dataset
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from peft import AutoPeftModelForCausalLM, LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
HfArgumentParser,
|
||||
is_torch_npu_available,
|
||||
is_torch_xpu_available,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
from trl import SFTConfig, SFTTrainer
|
||||
from trl.trainer import ConstantLengthDataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
|
||||
dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
|
||||
subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
|
||||
split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
|
||||
size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
|
||||
streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
|
||||
shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
|
||||
seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
|
||||
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
|
||||
use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
|
||||
|
||||
# LoraConfig
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
|
||||
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
|
||||
|
||||
|
||||
parser = HfArgumentParser((ScriptArguments, SFTConfig))
|
||||
script_args, training_args = parser.parse_args_into_dataclasses()
|
||||
peft_config = LoraConfig(
|
||||
r=script_args.lora_r,
|
||||
lora_alpha=script_args.lora_alpha,
|
||||
lora_dropout=script_args.lora_dropout,
|
||||
target_modules=["q_proj", "v_proj"],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
if training_args.group_by_length and training_args.packing:
|
||||
raise ValueError("Cannot use both packing and group by length")
|
||||
|
||||
# `gradient_checkpointing` was True by default until `1f3314`, but it's actually not used.
|
||||
# `gradient_checkpointing=True` will cause `Variable._execution_engine.run_backward`.
|
||||
if training_args.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing not supported")
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
|
||||
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
|
||||
"""
|
||||
Estimate the average number of characters per token in the dataset.
|
||||
"""
|
||||
total_characters, total_tokens = 0, 0
|
||||
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
|
||||
text = prepare_sample_text(example)
|
||||
total_characters += len(text)
|
||||
if tokenizer.is_fast:
|
||||
total_tokens += len(tokenizer(text).tokens())
|
||||
else:
|
||||
total_tokens += len(tokenizer.tokenize(text))
|
||||
|
||||
return total_characters / total_tokens
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
"""
|
||||
Prints the number of trainable parameters in the model.
|
||||
"""
|
||||
trainable_params = 0
|
||||
all_param = 0
|
||||
for _, param in model.named_parameters():
|
||||
all_param += param.numel()
|
||||
if param.requires_grad:
|
||||
trainable_params += param.numel()
|
||||
print(
|
||||
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
||||
)
|
||||
|
||||
|
||||
def prepare_sample_text(example):
|
||||
"""Prepare the text from a sample of the dataset."""
|
||||
text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
|
||||
return text
|
||||
|
||||
|
||||
def create_datasets(tokenizer, args, seed=None):
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
data_dir=args.subset,
|
||||
split=args.split,
|
||||
use_auth_token=True,
|
||||
num_proc=args.num_workers if not args.streaming else None,
|
||||
streaming=args.streaming,
|
||||
)
|
||||
if args.streaming:
|
||||
print("Loading the dataset in streaming mode")
|
||||
valid_data = dataset.take(args.size_valid_set)
|
||||
train_data = dataset.skip(args.size_valid_set)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
|
||||
else:
|
||||
dataset = dataset.train_test_split(test_size=0.005, seed=seed)
|
||||
train_data = dataset["train"]
|
||||
valid_data = dataset["test"]
|
||||
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
|
||||
|
||||
chars_per_token = chars_token_ratio(train_data, tokenizer)
|
||||
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
|
||||
|
||||
train_dataset = ConstantLengthDataset(
|
||||
tokenizer,
|
||||
train_data,
|
||||
formatting_func=prepare_sample_text,
|
||||
infinite=True,
|
||||
seq_length=args.seq_length,
|
||||
chars_per_token=chars_per_token,
|
||||
)
|
||||
valid_dataset = ConstantLengthDataset(
|
||||
tokenizer,
|
||||
valid_data,
|
||||
formatting_func=prepare_sample_text,
|
||||
infinite=False,
|
||||
seq_length=args.seq_length,
|
||||
chars_per_token=chars_per_token,
|
||||
)
|
||||
return train_dataset, valid_dataset
|
||||
|
||||
|
||||
bnb_config = None
|
||||
if script_args.use_bnb:
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name,
|
||||
quantization_config=bnb_config,
|
||||
device_map={"": Accelerator().local_process_index},
|
||||
trust_remote_code=True,
|
||||
use_auth_token=True,
|
||||
)
|
||||
base_model.config.use_cache = False
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
|
||||
|
||||
train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=base_model,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
peft_config=peft_config,
|
||||
max_length=None,
|
||||
formatting_func=prepare_sample_text,
|
||||
processing_class=tokenizer,
|
||||
args=training_args,
|
||||
)
|
||||
trainer.train()
|
||||
trainer.save_model(training_args.output_dir)
|
||||
|
||||
output_dir = os.path.join(training_args.output_dir, "final_checkpoint")
|
||||
trainer.model.save_pretrained(output_dir)
|
||||
|
||||
# Free memory for merging weights
|
||||
del base_model
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", dtype=torch.bfloat16)
|
||||
model = model.merge_and_unload()
|
||||
|
||||
output_merged_dir = os.path.join(training_args.output_dir, "final_merged_checkpoint")
|
||||
model.save_pretrained(output_merged_dir, safe_serialization=True)
|
@ -1,7 +0,0 @@
|
||||
# De-detoxifying language models
|
||||
|
||||
To run this code, do the following:
|
||||
|
||||
```shell
|
||||
ACCELERATE_LOG_LEVEL=info accelerate launch --config_file {CONFIG} examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py --log_with wandb
|
||||
```
|
@ -1,146 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
|
||||
import evaluate
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_npu_available, is_torch_xpu_available
|
||||
|
||||
|
||||
toxicity = evaluate.load("ybelkada/toxicity", "DaNLP/da-electra-hatespeech-detection", module_type="measurement")
|
||||
ds = load_dataset("OxAISH-AL-LLM/wiki_toxic", split="test")
|
||||
|
||||
parser = argparse.ArgumentParser(description="Evaluate de-toxified models")
|
||||
parser.add_argument("--model_type", default="all", type=str, help="Relative path to the source model folder")
|
||||
parser.add_argument("--output_file", default="toxicity.csv", type=str, help="Relative path to the source model folder")
|
||||
parser.add_argument("--batch_size", default=64, type=int, help="Batch size")
|
||||
parser.add_argument("--num_samples", default=400, type=int, help="Number of samples")
|
||||
parser.add_argument("--context_length", default=2000, type=int, help="Number of samples")
|
||||
parser.add_argument("--max_new_tokens", default=30, type=int, help="Max new tokens for generation")
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
if args.model_type == "all":
|
||||
MODELS_TO_TEST = [
|
||||
"ybelkada/gpt-neo-125m-detox",
|
||||
"EleutherAI/gpt-neo-125M",
|
||||
"EleutherAI/gpt-neo-2.7B",
|
||||
"ybelkada/gpt-neo-2.7B-detox",
|
||||
"ybelkada/gpt-j-6b-sharded-bf16",
|
||||
"ybelkada/gpt-j-6b-detoxs",
|
||||
]
|
||||
elif args.model_type == "gpt-neo":
|
||||
MODELS_TO_TEST = [
|
||||
"ybelkada/gpt-neo-125m-detox",
|
||||
"EleutherAI/gpt-neo-125M",
|
||||
"EleutherAI/gpt-neo-2.7B",
|
||||
"ybelkada/gpt-neo-2.7B-detox",
|
||||
]
|
||||
elif args.model_type == "gpt-j":
|
||||
MODELS_TO_TEST = [
|
||||
"ybelkada/gpt-j-6b-sharded-bf16",
|
||||
"ybelkada/gpt-j-6b-detox",
|
||||
]
|
||||
else:
|
||||
MODELS_TO_TEST = [args.model_type]
|
||||
NUM_SAMPLES = args.num_samples
|
||||
BATCH_SIZE = args.batch_size
|
||||
output_file = args.output_file
|
||||
max_new_tokens = args.max_new_tokens
|
||||
context_length = args.context_length
|
||||
if is_torch_xpu_available():
|
||||
device = torch.xpu.current_device()
|
||||
elif is_torch_npu_available():
|
||||
device = torch.npu.current_device()
|
||||
else:
|
||||
device = torch.cuda.current_device() if torch.cuda.is_available() else "cpu"
|
||||
|
||||
# consider only toxic prompts
|
||||
ds = ds.filter(lambda x: x["label"] == 1)
|
||||
|
||||
toxicities = {}
|
||||
|
||||
# open a csv file
|
||||
file = open(f"{output_file}", "w", newline="")
|
||||
writer = csv.writer(file)
|
||||
# add first rows
|
||||
writer.writerow(["model_id", "mean_toxicity", "std_toxicity"])
|
||||
|
||||
|
||||
for model_id in tqdm(MODELS_TO_TEST):
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": device}, dtype=torch.bfloat16)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "left"
|
||||
input_texts = []
|
||||
|
||||
for i, example in enumerate(ds):
|
||||
# set seed
|
||||
torch.manual_seed(42)
|
||||
|
||||
input_text = example["comment_text"]
|
||||
input_texts.append(input_text[:2000])
|
||||
|
||||
if i > NUM_SAMPLES:
|
||||
break
|
||||
|
||||
if (i + 1) % BATCH_SIZE == 0:
|
||||
inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device)
|
||||
inputs.input_ids = inputs.input_ids[:context_length]
|
||||
inputs.attention_mask = inputs.attention_mask[:context_length]
|
||||
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=max_new_tokens, use_cache=True)
|
||||
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
generated_texts = [
|
||||
generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts)
|
||||
]
|
||||
toxicity_score = toxicity.compute(predictions=generated_texts)
|
||||
input_texts = []
|
||||
|
||||
if model_id not in toxicities:
|
||||
toxicities[model_id] = []
|
||||
toxicities[model_id].extend(toxicity_score["toxicity"])
|
||||
|
||||
# last batch
|
||||
inputs = tokenizer(input_texts, return_tensors="pt", padding=True).to(device)
|
||||
outputs = model.generate(**inputs, do_sample=True, max_new_tokens=30)
|
||||
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
||||
generated_texts = [generated_text.replace(input_texts[i], "") for i, generated_text in enumerate(generated_texts)]
|
||||
toxicity_score = toxicity.compute(predictions=generated_texts)
|
||||
toxicities[model_id].extend(toxicity_score["toxicity"])
|
||||
|
||||
# compute mean & std using np
|
||||
mean = np.mean(toxicities[model_id])
|
||||
std = np.std(toxicities[model_id])
|
||||
|
||||
# save to file
|
||||
writer.writerow([model_id, mean, std])
|
||||
|
||||
# print
|
||||
print(f"Model: {model_id} - Mean: {mean} - Std: {std}")
|
||||
|
||||
model = None
|
||||
if is_torch_xpu_available():
|
||||
torch.xpu.empty_cache()
|
||||
elif is_torch_npu_available():
|
||||
torch.npu.empty_cache()
|
||||
else:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# close file
|
||||
file.close()
|
@ -1,245 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Adam
|
||||
from tqdm import tqdm
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
RobertaForSequenceClassification,
|
||||
RobertaTokenizer,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, create_reference_model
|
||||
from trl.core import LengthSampler
|
||||
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
########################################################################
|
||||
# This is a fully working simple example to use trl with accelerate.
|
||||
#
|
||||
# This example fine-tunes a GPTJ model to generate less toxic contents
|
||||
# by using allenai/real-toxicity-prompts dataset. We use PPO
|
||||
# (proximal policy optimization) to optimize the model.
|
||||
# in any of the following settings (with the same script):
|
||||
# - single CPU or single GPU
|
||||
# - multi GPUS (using PyTorch distributed mode)
|
||||
# - multi GPUS (using DeepSpeed ZeRO-Offload stages 1 & 2)
|
||||
# - fp16 (mixed-precision) or fp32 (normal precision)
|
||||
#
|
||||
# To run it in each of these various modes, first initialize the accelerate
|
||||
# configuration with `accelerate config`
|
||||
#
|
||||
########################################################################
|
||||
|
||||
|
||||
# We first define the configuration of the experiment, defining the model, the dataset,
|
||||
# the training parameters, and the PPO parameters.
|
||||
# Check the default arguments in the `PPOConfig` class for more details.
|
||||
# If you want to log with tensorboard, add the kwarg
|
||||
# `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Casual LM model we wish to fine-tune with PPO
|
||||
"""
|
||||
|
||||
# NOTE: gpt2 models use Conv1D instead of Linear layers which are not yet supported in 8 bit mode
|
||||
# models like gpt-neo* models are more suitable.
|
||||
model_name: Optional[str] = field(default="ybelkada/gpt-j-6b-sharded-bf16", metadata={"help": "the model name"})
|
||||
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
|
||||
learning_rate: Optional[float] = field(default=(1.47e-5) * 2, metadata={"help": "the learning rate"})
|
||||
mini_batch_size: Optional[int] = field(default=4, metadata={"help": "the PPO minibatch size"})
|
||||
batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=1, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
model_save_path: Optional[str] = field(
|
||||
default="./gpt-j-6B-detoxified-long-context-26-shl-1e4-final",
|
||||
metadata={"help": "the path to save the model"},
|
||||
)
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
config = PPOConfig(
|
||||
model_name=script_args.model_name,
|
||||
learning_rate=script_args.learning_rate,
|
||||
log_with=script_args.log_with,
|
||||
ppo_epochs=100,
|
||||
mini_batch_size=script_args.mini_batch_size,
|
||||
batch_size=script_args.batch_size,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
|
||||
# Below is an example function to build the dataset. In our case, we use the IMDB dataset
|
||||
# from the `datasets` library. One should customize this function to train the model on
|
||||
# its own dataset.
|
||||
def build_dataset(
|
||||
config, dataset_name="allenai/real-toxicity-prompts", input_min_text_length=5, input_max_text_length=10
|
||||
):
|
||||
"""
|
||||
Build dataset for training. This builds the dataset from `load_dataset`, one should
|
||||
customize this function to train the model on its own dataset.
|
||||
|
||||
Args:
|
||||
config (`PPOConfig`):
|
||||
The configuration of the PPO training.
|
||||
dataset_name (`str`):
|
||||
The name of the dataset to be loaded.
|
||||
input_min_text_length (`int`, defaults to 5):
|
||||
The minimum length of the input text.
|
||||
input_max_text_length (`int`, defaults to 10):
|
||||
The maximum length of the input text.
|
||||
|
||||
Returns:
|
||||
dataloader (`torch.utils.data.DataLoader`):
|
||||
The dataloader for the dataset.
|
||||
"""
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
ds = load_dataset(dataset_name, split="train")
|
||||
|
||||
def filter_fn(sample):
|
||||
toxicity = sample["prompt"]["toxicity"]
|
||||
return toxicity is not None and toxicity > 0.3
|
||||
|
||||
ds = ds.filter(filter_fn, batched=False)
|
||||
|
||||
input_size = LengthSampler(input_min_text_length, input_max_text_length)
|
||||
|
||||
def tokenize(sample):
|
||||
prompt = sample["prompt"]["text"]
|
||||
continuation = sample["continuation"]["text"]
|
||||
|
||||
sample["input_ids"] = tokenizer.encode(prompt + continuation)[: input_size()]
|
||||
sample["query"] = tokenizer.decode(sample["input_ids"])
|
||||
return sample
|
||||
|
||||
ds = ds.map(tokenize, batched=False)
|
||||
ds.set_format(type="torch")
|
||||
|
||||
ds = ds.train_test_split(test_size=0.2, shuffle=False)["train"]
|
||||
|
||||
return ds
|
||||
|
||||
|
||||
# We retrieve the dataloader by calling the `build_dataset` function.
|
||||
min_input_length = 30
|
||||
max_input_length = 40
|
||||
dataset = build_dataset(config, input_min_text_length=min_input_length, input_max_text_length=max_input_length)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
set_seed(config.seed)
|
||||
|
||||
# Now let's build the model, the reference model, and the tokenizer. We first load the model
|
||||
# in bfloat16 to save memory using `transformers`.
|
||||
model = AutoModelForCausalLM.from_pretrained(config.model_name, dtype=torch.bfloat16)
|
||||
# And then we pass the loaded model to `AutoModelForCausalLMWithValueHead`.
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
|
||||
|
||||
# We create a reference model by sharing 20 layers
|
||||
ref_model = create_reference_model(model, num_shared_layers=20)
|
||||
|
||||
# We make sure to use `Adam` optimizer on the model parameters that require gradients.
|
||||
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=config.learning_rate)
|
||||
|
||||
# GPT-2 / GPT-J tokenizer has a pad token, but it is not eos_token by default. We need to set it to eos_token.
|
||||
# only for this model.
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
|
||||
ppo_trainer = PPOTrainer(
|
||||
config,
|
||||
model,
|
||||
ref_model=ref_model,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dataset,
|
||||
data_collator=collator,
|
||||
optimizer=optimizer,
|
||||
)
|
||||
|
||||
# We then build the reward pipeline, we will use the toxicity model to compute the reward.
|
||||
# We first load the toxicity model and tokenizer.
|
||||
toxicity_model_id = "facebook/roberta-hate-speech-dynabench-r4-target"
|
||||
toxicity_tokenizer = RobertaTokenizer.from_pretrained(toxicity_model_id)
|
||||
# We load the toxicity model in fp16 to save memory.
|
||||
toxicity_model = RobertaForSequenceClassification.from_pretrained(toxicity_model_id, dtype=torch.float16).to(
|
||||
ppo_trainer.accelerator.device
|
||||
)
|
||||
|
||||
|
||||
# We then define the arguments to pass to the `generate` function. These arguments
|
||||
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
|
||||
# the `generate` function of the trained model.
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
}
|
||||
output_min_length = 20
|
||||
output_max_length = 30
|
||||
output_length_sampler = LengthSampler(output_min_length, output_max_length)
|
||||
|
||||
model_save_path = script_args.model_save_path
|
||||
|
||||
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
# Get response from the policy model
|
||||
response_tensors = []
|
||||
for query in query_tensors:
|
||||
gen_len = output_length_sampler()
|
||||
generation_kwargs["max_new_tokens"] = gen_len
|
||||
response = ppo_trainer.generate(query, **generation_kwargs)
|
||||
response_tensors.append(response.squeeze()[-gen_len:])
|
||||
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
|
||||
|
||||
# Compute sentiment score
|
||||
texts = batch["response"]
|
||||
toxicity_inputs = toxicity_tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(
|
||||
ppo_trainer.accelerator.device
|
||||
)
|
||||
logits = toxicity_model(**toxicity_inputs).logits.float()
|
||||
toxicity_labels = (logits[:, 0]).tolist()
|
||||
|
||||
rewards = [torch.tensor(output) for output in toxicity_labels]
|
||||
|
||||
# Run PPO step
|
||||
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
|
||||
# Save model every 100 epochs
|
||||
if epoch % 100 == 0:
|
||||
if ppo_trainer.accelerator.is_main_process:
|
||||
ppo_trainer.save_pretrained(model_save_path)
|
@ -90,26 +90,32 @@ vlm = [
|
||||
"num2words==0.5.14"
|
||||
]
|
||||
dev = [
|
||||
# bco
|
||||
"scikit-learn",
|
||||
"joblib",
|
||||
# deepspeed
|
||||
"deepspeed>=0.14.4",
|
||||
# judges
|
||||
"openai>=1.23.2",
|
||||
"llm-blender>=0.0.2",
|
||||
# liger
|
||||
"liger-kernel>=0.6.2",
|
||||
# peft
|
||||
"peft>=0.8.0",
|
||||
# quality
|
||||
"pre-commit",
|
||||
"hf-doc-builder",
|
||||
# quantization
|
||||
"bitsandbytes",
|
||||
# scikit: included in bco
|
||||
# test
|
||||
"parameterized",
|
||||
"pytest-cov",
|
||||
"pytest-rerunfailures==15.1",
|
||||
"pytest-xdist",
|
||||
"pytest",
|
||||
"vllm==0.10.2",
|
||||
"fastapi",
|
||||
"pydantic",
|
||||
"requests",
|
||||
"uvicorn",
|
||||
# vllm: not included in dev by default due to CUDA error; see GH-4228
|
||||
# vlm
|
||||
"Pillow",
|
||||
"torchvision",
|
||||
"num2words==0.5.14"
|
||||
|
@ -1,158 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from datetime import date
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
|
||||
MAX_LEN_MESSAGE = 2900 # slack endpoint has a limit of 3001 characters
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--slack_channel_name", default="trl-push-examples-ci")
|
||||
parser.add_argument("--text_file_name", required=True)
|
||||
|
||||
|
||||
def main(text_file_name, slack_channel_name=None):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
message = ""
|
||||
|
||||
if os.path.isfile(text_file_name):
|
||||
final_results = {}
|
||||
|
||||
try:
|
||||
with open(text_file_name) as file:
|
||||
for line in file:
|
||||
result, config_name = line.strip().split(",")
|
||||
config_name = config_name.split("/")[-1].split(".yaml")[0]
|
||||
final_results[config_name] = int(result)
|
||||
except Exception as e:
|
||||
logger.error(f"Error reading file {text_file_name}: {str(e)}")
|
||||
final_results = {}
|
||||
|
||||
no_error_payload = {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "🌞 There were no failures on the example tests!"
|
||||
if not len(final_results) == 0
|
||||
else "Something went wrong there is at least one empty file - please check GH action results.",
|
||||
"emoji": True,
|
||||
},
|
||||
}
|
||||
|
||||
total_num_failed = sum(final_results.values())
|
||||
else:
|
||||
no_error_payload = {
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "❌ Something is wrong with the workflow please check ASAP!"
|
||||
"Something went wrong there is no text file being produced. Please check ASAP.",
|
||||
"emoji": True,
|
||||
},
|
||||
}
|
||||
|
||||
total_num_failed = 0
|
||||
|
||||
test_type_name = text_file_name.replace(".txt", "").replace("temp_results_", "").replace("_", " ").title()
|
||||
|
||||
payload = [
|
||||
{
|
||||
"type": "header",
|
||||
"text": {
|
||||
"type": "plain_text",
|
||||
"text": "🤗 Results of the {} TRL {} example tests.".format(
|
||||
os.environ.get("TEST_TYPE", ""), test_type_name
|
||||
),
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if total_num_failed > 0:
|
||||
message += f"{total_num_failed} failed tests for example tests!"
|
||||
|
||||
for test_name, failed in final_results.items():
|
||||
failed_table = tabulate(
|
||||
[[test_name, "✅" if not failed else "❌"]],
|
||||
headers=["Test Name", "Status"],
|
||||
showindex="always",
|
||||
tablefmt="grid",
|
||||
maxcolwidths=[12],
|
||||
)
|
||||
message += "\n```\n" + failed_table + "\n```"
|
||||
|
||||
print(f"### {message}")
|
||||
else:
|
||||
payload.append(no_error_payload)
|
||||
|
||||
if os.environ.get("TEST_TYPE", "") != "":
|
||||
try:
|
||||
from slack_sdk import WebClient
|
||||
except ImportError:
|
||||
logger.error("slack_sdk is not installed. Please install it to use Slack integration.")
|
||||
return
|
||||
|
||||
if len(message) > MAX_LEN_MESSAGE:
|
||||
print(f"Truncating long message from {len(message)} to {MAX_LEN_MESSAGE}")
|
||||
message = message[:MAX_LEN_MESSAGE] + "..."
|
||||
|
||||
if len(message) != 0:
|
||||
md_report = {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": message},
|
||||
}
|
||||
payload.append(md_report)
|
||||
action_button = {
|
||||
"type": "section",
|
||||
"text": {"type": "mrkdwn", "text": "*For more details:*"},
|
||||
"accessory": {
|
||||
"type": "button",
|
||||
"text": {"type": "plain_text", "text": "Check Action results", "emoji": True},
|
||||
"url": f"https://github.com/huggingface/trl/actions/runs/{os.environ['GITHUB_RUN_ID']}",
|
||||
},
|
||||
}
|
||||
payload.append(action_button)
|
||||
|
||||
date_report = {
|
||||
"type": "context",
|
||||
"elements": [
|
||||
{
|
||||
"type": "plain_text",
|
||||
"text": f"On Push - main {os.environ.get('TEST_TYPE')} test results for {date.today()}",
|
||||
},
|
||||
],
|
||||
}
|
||||
payload.append(date_report)
|
||||
|
||||
print(payload)
|
||||
|
||||
try:
|
||||
client = WebClient(token=os.environ.get("SLACK_API_TOKEN"))
|
||||
response = client.chat_postMessage(channel=f"#{slack_channel_name}", text=message, blocks=payload)
|
||||
if response["ok"]:
|
||||
logger.info("Message sent successfully to Slack.")
|
||||
else:
|
||||
logger.error(f"Failed to send message to Slack: {response['error']}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending message to Slack: {str(e)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
main(args.text_file_name, args.slack_channel_name)
|
@ -12,17 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from huggingface_hub import whoami
|
||||
import gc
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
model_name = "unsloth/Llama-3.2-3B"
|
||||
tokenizer_name = "unsloth/Llama-3.2-3B"
|
||||
dataset_name = "WillHeld/top_v2"
|
||||
@pytest.fixture(autouse=True)
|
||||
def cleanup_gpu():
|
||||
"""
|
||||
Automatically cleanup GPU memory after each test.
|
||||
|
||||
output_root_dir = "./checkpoints/"
|
||||
hub_model_id = f"{whoami()['name']}/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}"
|
||||
output_dir = f"{output_root_dir}/{hub_model_id}"
|
||||
|
||||
per_device_train_batch_size = 8
|
||||
gradient_accumulation_steps = 1
|
||||
learning_rate = 2e-5
|
||||
This fixture helps prevent CUDA out of memory errors when running tests in parallel
|
||||
with pytest-xdist by ensuring models and tensors are properly garbage collected
|
||||
and GPU memory caches are cleared between tests.
|
||||
"""
|
||||
yield
|
||||
# Cleanup after test
|
||||
gc.collect()
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
@ -118,6 +118,7 @@ class TestGRPOTrainerSlow(TrlTestCase):
|
||||
max_completion_length=self.max_length,
|
||||
report_to="none",
|
||||
logging_strategy="no",
|
||||
loss_type="bnpo", # liger-kernel does not support "dapo" default; see https://github.com/linkedin/Liger-Kernel/issues/620
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
|
@ -412,12 +412,12 @@ class TestSFTTrainerSlow(TrlTestCase):
|
||||
eval_dataset=self.eval_dataset,
|
||||
)
|
||||
|
||||
# Register cleanup now that we have the trainer
|
||||
self.addCleanup(cleanup_liger_patches, trainer)
|
||||
|
||||
trainer.train()
|
||||
|
||||
release_memory(trainer.model, trainer)
|
||||
# Ensure cleanup of liger patches after the test
|
||||
try:
|
||||
trainer.train()
|
||||
release_memory(trainer.model, trainer)
|
||||
finally:
|
||||
cleanup_liger_patches(trainer)
|
||||
|
||||
@parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS)))
|
||||
@require_torch_accelerator
|
||||
|
@ -1,113 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, GenerationConfig
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
from trl.core import LengthSampler
|
||||
from trl.extras import BestOfNSampler
|
||||
|
||||
from .testing_utils import TrlTestCase
|
||||
|
||||
|
||||
def queries_to_scores(list_of_strings):
|
||||
return [torch.rand(1).item() for _ in list_of_strings]
|
||||
|
||||
|
||||
class TestBestOfNSampler(TrlTestCase):
|
||||
"""
|
||||
Tests the BestOfNSampler class
|
||||
"""
|
||||
|
||||
ref_model_name = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
output_length_sampler = LengthSampler(2, 6)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ref_model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
output_length_sampler = LengthSampler(2, 6)
|
||||
|
||||
def test_different_input_types(self):
|
||||
r"""
|
||||
Tests if the different input types normalizer works
|
||||
"""
|
||||
|
||||
generation_config = GenerationConfig(
|
||||
min_length=-1,
|
||||
top_k=0.0,
|
||||
top_p=1.0,
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
output_length_sampler = LengthSampler(2, 6)
|
||||
|
||||
best_of_n = BestOfNSampler(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
queries_to_scores,
|
||||
length_sampler=output_length_sampler,
|
||||
generation_config=generation_config,
|
||||
)
|
||||
|
||||
queries = ["hello world", "goodbye world"]
|
||||
tokenized_queries = [self.tokenizer.encode(query) for query in queries]
|
||||
|
||||
various_queries_formats = [
|
||||
(tokenized_queries[0], 1),
|
||||
(tokenized_queries, 2),
|
||||
(torch.tensor(tokenized_queries[1]), 1),
|
||||
([torch.tensor(query) for query in tokenized_queries], 2),
|
||||
]
|
||||
|
||||
for q, expected_length in various_queries_formats:
|
||||
results = best_of_n.generate(q)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == expected_length
|
||||
|
||||
def test_different_sample_sizes_and_n_candidates_values(self):
|
||||
r"""
|
||||
Tests different sample sizes and n_candidates values
|
||||
"""
|
||||
generation_config = GenerationConfig(
|
||||
min_length=-1,
|
||||
top_k=0.0,
|
||||
top_p=1.0,
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
output_length_sampler = LengthSampler(6, 10)
|
||||
|
||||
for sample_value, n_candidates_values, expected in [
|
||||
(4, 2, 2),
|
||||
(10, 3, 3),
|
||||
(6, 4, 4),
|
||||
]:
|
||||
best_of_n = BestOfNSampler(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
queries_to_scores,
|
||||
length_sampler=output_length_sampler,
|
||||
generation_config=generation_config,
|
||||
sample_size=sample_value,
|
||||
n_candidates=n_candidates_values,
|
||||
)
|
||||
|
||||
queries = ["hello world", "troll the world"]
|
||||
tokenized_queries = [self.tokenizer.encode(query) for query in queries]
|
||||
results = best_of_n.generate(tokenized_queries)
|
||||
for result in results:
|
||||
assert len(result) == expected
|
@ -396,6 +396,29 @@ class TestApplyChatTemplate(TrlTestCase):
|
||||
assert isinstance(result["label"], bool)
|
||||
assert result["label"] == example["label"]
|
||||
|
||||
def test_apply_chat_template_with_chat_template_kwargs(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")
|
||||
|
||||
example = {
|
||||
"prompt": [{"role": "user", "content": "What color is the sky?"}],
|
||||
# with this tokenizer, when you pass enable_thinking=False, it will add "<think>\n\n</think>\n\n"
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
}
|
||||
result = apply_chat_template(example, tokenizer)
|
||||
|
||||
# docstyle-ignore
|
||||
expected = textwrap.dedent("""\
|
||||
<|im_start|>user
|
||||
What color is the sky?<|im_end|>
|
||||
<|im_start|>assistant
|
||||
<think>
|
||||
|
||||
</think>
|
||||
|
||||
""")
|
||||
|
||||
assert result["prompt"] == expected
|
||||
|
||||
def test_apply_chat_template_with_tools(self):
|
||||
tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")
|
||||
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import pytest
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
|
||||
|
||||
@ -23,6 +24,7 @@ from trl.models.utils import ChatMlSpecialTokens, clone_chat_template, setup_cha
|
||||
from .testing_utils import TrlTestCase
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::FutureWarning")
|
||||
class TestDatasetFormatting(TrlTestCase):
|
||||
def setup_method(self):
|
||||
self.llama_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-MistralForCausalLM-0.1")
|
||||
|
@ -33,12 +33,18 @@ from transformers import (
|
||||
from transformers.testing_utils import (
|
||||
get_device_properties,
|
||||
require_liger_kernel,
|
||||
require_torch_gpu_if_bnb_not_multi_backend_enabled,
|
||||
)
|
||||
|
||||
from trl import DPOConfig, DPOTrainer, FDivergenceType
|
||||
|
||||
from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb, require_peft, require_vision
|
||||
from .testing_utils import (
|
||||
TrlTestCase,
|
||||
require_bitsandbytes,
|
||||
require_no_wandb,
|
||||
require_peft,
|
||||
require_torch_gpu_if_bnb_not_multi_backend_enabled,
|
||||
require_vision,
|
||||
)
|
||||
|
||||
|
||||
if is_vision_available():
|
||||
|
@ -1471,47 +1471,6 @@ class TestGRPOTrainer(TrlTestCase):
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
|
||||
|
||||
@require_vision
|
||||
def test_training_vlm_and_prompt_truncation(self):
|
||||
# If not handled properly, prompt truncation may truncate image token
|
||||
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
|
||||
|
||||
def reward_func(completions, **kwargs):
|
||||
"""Reward function that rewards longer completions."""
|
||||
return [float(len(completion[0]["content"])) for completion in completions]
|
||||
|
||||
training_args = GRPOConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
learning_rate=0.1, # increase the learning rate to speed up the test
|
||||
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
|
||||
num_generations=3, # reduce the number of generations to reduce memory usage
|
||||
max_completion_length=8, # reduce the completion length to reduce memory usage
|
||||
max_prompt_length=18,
|
||||
report_to="none",
|
||||
)
|
||||
trainer = GRPOTrainer(
|
||||
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
|
||||
reward_funcs=reward_func,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# Check that the params have changed
|
||||
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
|
||||
# vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
|
||||
params_to_skip = ("model.visual.",)
|
||||
for n, param in previous_trainable_params.items():
|
||||
if n.startswith(params_to_skip):
|
||||
continue
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import time
|
||||
|
||||
import pytest
|
||||
@ -59,6 +60,9 @@ class TestJudges(TrlTestCase):
|
||||
raise ValueError("Failed to load PairRMJudge")
|
||||
|
||||
@require_llm_blender
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info[:3] == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)"
|
||||
)
|
||||
def test_pair_rm_judge(self):
|
||||
judge = self.load_pair_rm_judge()
|
||||
prompts, completions = self._get_prompts_and_pairwise_completions()
|
||||
@ -68,6 +72,9 @@ class TestJudges(TrlTestCase):
|
||||
assert ranks == [0, 1]
|
||||
|
||||
@require_llm_blender
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info[:3] == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)"
|
||||
)
|
||||
def test_pair_rm_judge_return_scores(self):
|
||||
judge = self.load_pair_rm_judge()
|
||||
prompts, completions = self._get_prompts_and_pairwise_completions()
|
||||
|
@ -16,12 +16,11 @@ import os
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers.testing_utils import require_torch_gpu_if_bnb_not_multi_backend_enabled
|
||||
from transformers.utils import is_peft_available
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead
|
||||
|
||||
from .testing_utils import TrlTestCase, require_peft
|
||||
from .testing_utils import TrlTestCase, require_peft, require_torch_gpu_if_bnb_not_multi_backend_enabled
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
|
@ -1212,47 +1212,6 @@ class TestRLOOTrainer(TrlTestCase):
|
||||
elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer)
|
||||
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed."
|
||||
|
||||
@require_vision
|
||||
def test_training_vlm_and_prompt_truncation(self):
|
||||
# If not handled properly, prompt truncation may truncate image token
|
||||
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train")
|
||||
|
||||
def reward_func(completions, **kwargs):
|
||||
"""Reward function that rewards longer completions."""
|
||||
return [float(len(completion[0]["content"])) for completion in completions]
|
||||
|
||||
training_args = RLOOConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
learning_rate=0.1, # increase the learning rate to speed up the test
|
||||
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
|
||||
num_generations=3, # reduce the number of generations to reduce memory usage
|
||||
max_completion_length=8, # reduce the completion length to reduce memory usage
|
||||
max_prompt_length=18,
|
||||
report_to="none",
|
||||
)
|
||||
trainer = RLOOTrainer(
|
||||
model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",
|
||||
reward_funcs=reward_func,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# Check that the params have changed
|
||||
# Because of the way the tiny models are initialized, the gradient does not flow properly through the
|
||||
# vision parts of the model, so we skip them. Ideally, we should fix the init of these models.
|
||||
params_to_skip = ("model.visual.",)
|
||||
for n, param in previous_trainable_params.items():
|
||||
if n.startswith(params_to_skip):
|
||||
continue
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",),
|
||||
|
@ -32,7 +32,15 @@ from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes, r
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig, PeftModel, PromptEncoderConfig, TaskType, get_peft_model
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
PeftModel,
|
||||
PrefixTuningConfig,
|
||||
PromptEncoderConfig,
|
||||
PromptTuningConfig,
|
||||
TaskType,
|
||||
get_peft_model,
|
||||
)
|
||||
|
||||
|
||||
class TestDFTLoss(TrlTestCase):
|
||||
@ -453,7 +461,7 @@ class TestSFTTrainer(TrlTestCase):
|
||||
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
|
||||
|
||||
@require_peft
|
||||
def test_train_dense_with_peft_config(self):
|
||||
def test_train_dense_with_peft_config_lora(self):
|
||||
# Get the base model parameter names
|
||||
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
@ -489,6 +497,66 @@ class TestSFTTrainer(TrlTestCase):
|
||||
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
|
||||
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
("prompt_tuning",),
|
||||
("prefix_tuning",),
|
||||
("prompt_encoder",),
|
||||
]
|
||||
)
|
||||
@require_peft
|
||||
def test_train_with_peft_config_prompt_tuning(self, peft_type):
|
||||
# Get the base model parameter names
|
||||
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id)
|
||||
base_param_names = [f"base_model.{n}" for n, _ in model.named_parameters()]
|
||||
|
||||
# Get the dataset
|
||||
dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
|
||||
|
||||
# Initialize the trainer, p-tuning doesn't support gradient checkpointing
|
||||
training_args = SFTConfig(bf16=False, output_dir=self.tmp_dir, report_to="none", gradient_checkpointing=False)
|
||||
if peft_type == "prompt_tuning":
|
||||
peft_config = PromptTuningConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
num_virtual_tokens=4,
|
||||
tokenizer_name_or_path="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
|
||||
)
|
||||
elif peft_type == "prefix_tuning":
|
||||
peft_config = PrefixTuningConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
num_virtual_tokens=4,
|
||||
)
|
||||
elif peft_type == "prompt_encoder":
|
||||
peft_config = PromptEncoderConfig(
|
||||
task_type=TaskType.CAUSAL_LM,
|
||||
num_virtual_tokens=4,
|
||||
encoder_hidden_size=model.config.hidden_size, # This will be overwritten below
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model=model_id,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
# Save the initial parameters to compare them later
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
# Check that the training loss is not None
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# Check the peft params have changed and the base model params have not changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
if n in base_param_names: # We expect the base model parameters to be the same
|
||||
assert torch.allclose(param, new_param), f"Parameter {n} has changed"
|
||||
else: # We expect the peft parameters to be different
|
||||
assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
|
||||
|
||||
@require_peft
|
||||
def test_train_moe_with_peft_config(self):
|
||||
# Get the base model parameter names
|
||||
@ -1373,6 +1441,38 @@ class TestSFTTrainer(TrlTestCase):
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
|
||||
|
||||
# Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly in the collator.
|
||||
@require_vision
|
||||
def test_train_vlm_prompt_completion_gemma(self):
|
||||
# Get the dataset
|
||||
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")
|
||||
|
||||
# Initialize the trainer
|
||||
training_args = SFTConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
max_length=None, # For VLMs, truncating can remove image tokens, leading to errors
|
||||
report_to="none",
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model="trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
# Save the initial parameters to compare them later
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
# Check that the training loss is not None
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# Check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
|
||||
|
||||
# Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
|
||||
# To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
|
||||
@pytest.mark.slow
|
||||
|
@ -42,7 +42,6 @@ from trl.trainer.utils import (
|
||||
shuffle_sequence_dict,
|
||||
split_pixel_values_by_grid,
|
||||
split_tensor_dict,
|
||||
truncate_with_protected_tokens,
|
||||
unsplit_pixel_values_by_grid,
|
||||
)
|
||||
|
||||
@ -1009,84 +1008,6 @@ class TestSplitPixelValuesByGrid(TrlTestCase):
|
||||
assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))
|
||||
|
||||
|
||||
class TestTruncateWithProtectedTokens(TrlTestCase):
|
||||
def test_basic_example(self):
|
||||
"""Test the basic example from the problem description."""
|
||||
prompt_ids = [1, 2, 3, 4, 5]
|
||||
protected_tokens = [2, 3]
|
||||
target_length = 3
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
expected_ids = [2, 3, 5]
|
||||
assert new_ids == expected_ids
|
||||
|
||||
def test_no_truncation_needed(self):
|
||||
"""Test when target length equals current length."""
|
||||
prompt_ids = [1, 2, 3]
|
||||
protected_tokens = [2]
|
||||
target_length = 3
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
assert new_ids == prompt_ids
|
||||
|
||||
def test_no_protected_tokens(self):
|
||||
"""Test truncation with no protected tokens (normal right truncation)."""
|
||||
prompt_ids = [1, 2, 3, 4, 5]
|
||||
protected_tokens = []
|
||||
target_length = 3
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
expected_ids = [3, 4, 5] # Last 3 tokens
|
||||
assert new_ids == expected_ids
|
||||
|
||||
def test_all_tokens_protected(self):
|
||||
"""Test when all remaining tokens are protected."""
|
||||
prompt_ids = [1, 2, 3, 4, 5]
|
||||
protected_tokens = [3, 4, 5]
|
||||
target_length = 3
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
expected_ids = [3, 4, 5]
|
||||
assert new_ids == expected_ids
|
||||
|
||||
def test_too_many_protected_tokens(self):
|
||||
"""Test error when too many protected tokens for target length."""
|
||||
prompt_ids = [1, 2, 3, 4, 5]
|
||||
protected_tokens = [1, 2, 3, 4]
|
||||
target_length = 3
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
def test_single_batch_single_token(self):
|
||||
"""Test edge case with single batch and single token."""
|
||||
prompt_ids = [5]
|
||||
protected_tokens = [5]
|
||||
target_length = 1
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
assert new_ids == prompt_ids
|
||||
|
||||
def test_order_preservation(self):
|
||||
"""Test that relative order is preserved."""
|
||||
prompt_ids = [10, 2, 20, 3, 30, 40]
|
||||
protected_tokens = [2, 3]
|
||||
target_length = 4
|
||||
|
||||
new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens)
|
||||
|
||||
# Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40
|
||||
# Order should be: 2, 3, 30, 40 (maintaining original relative positions)
|
||||
expected_ids = [2, 3, 30, 40]
|
||||
|
||||
assert new_ids == expected_ids
|
||||
|
||||
|
||||
class TestUnsplitPixelValuesByGrid(TrlTestCase):
|
||||
def test_unsplit_correctly(self):
|
||||
pixel_values = [torch.randn(4, 5), torch.randn(2, 5)]
|
||||
|
@ -16,6 +16,7 @@ import functools
|
||||
import random
|
||||
import signal
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
|
||||
import psutil
|
||||
import pytest
|
||||
@ -46,6 +47,21 @@ require_3_accelerators = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
|
||||
def is_bitsandbytes_multi_backend_available() -> bool:
|
||||
if is_bitsandbytes_available():
|
||||
import bitsandbytes as bnb
|
||||
|
||||
return "multi_backend" in getattr(bnb, "features", set())
|
||||
return False
|
||||
|
||||
|
||||
# Function ported from transformers.testing_utils before transformers#41283
|
||||
require_torch_gpu_if_bnb_not_multi_backend_enabled = pytest.mark.skipif(
|
||||
not is_bitsandbytes_multi_backend_available() and not torch_device == "cuda",
|
||||
reason="test requires bitsandbytes multi-backend enabled or 'cuda' torch device",
|
||||
)
|
||||
|
||||
|
||||
class RandomBinaryJudge(BaseBinaryJudge):
|
||||
"""
|
||||
Random binary judge, for testing purposes.
|
||||
@ -73,7 +89,7 @@ class TrlTestCase:
|
||||
self.tmp_dir = str(tmp_path)
|
||||
|
||||
|
||||
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> callable:
|
||||
def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> Callable:
|
||||
"""
|
||||
Decorator to ignore warnings with a specific message and/or category.
|
||||
|
||||
|
@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .import_utils import _LazyModule
|
||||
@ -40,7 +39,6 @@ _import_structure = {
|
||||
"truncate_dataset",
|
||||
"unpair_preference_dataset",
|
||||
],
|
||||
"extras": ["BestOfNSampler"],
|
||||
"models": [
|
||||
"SUPPORTED_ARCHITECTURES",
|
||||
"AutoModelForCausalLMWithValueHead",
|
||||
@ -120,7 +118,6 @@ if TYPE_CHECKING:
|
||||
truncate_dataset,
|
||||
unpair_preference_dataset,
|
||||
)
|
||||
from .extras import BestOfNSampler
|
||||
from .models import (
|
||||
SUPPORTED_ARCHITECTURES,
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
|
@ -143,7 +143,13 @@ def apply_chat_template(
|
||||
|
||||
# Apply the chat template to the whole conversation
|
||||
if "messages" in example:
|
||||
messages = tokenizer.apply_chat_template(example["messages"], tools=tools, tokenize=False, **template_kwargs)
|
||||
messages = tokenizer.apply_chat_template(
|
||||
example["messages"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
|
||||
# Apply the chat template to the prompt, adding the generation prompt
|
||||
if "prompt" in example:
|
||||
@ -162,6 +168,7 @@ def apply_chat_template(
|
||||
continue_final_message=continue_final_message,
|
||||
tokenize=False,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
|
||||
@ -169,7 +176,11 @@ def apply_chat_template(
|
||||
if "prompt" in example: # explicit prompt and prompt-completion case
|
||||
if "chosen" in example:
|
||||
prompt_chosen = tokenizer.apply_chat_template(
|
||||
example["prompt"] + example["chosen"], tools=tools, tokenize=False, **template_kwargs
|
||||
example["prompt"] + example["chosen"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
# DeepSeek-R1 inserts a <tool_call> token when using `add_generation_prompt`, which can cause discrepancies
|
||||
# between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the
|
||||
@ -179,24 +190,42 @@ def apply_chat_template(
|
||||
chosen = prompt_chosen[len(prompt) :]
|
||||
if "rejected" in example and "prompt" in example: # explicit prompt
|
||||
prompt_rejected = tokenizer.apply_chat_template(
|
||||
example["prompt"] + example["rejected"], tools=tools, tokenize=False, **template_kwargs
|
||||
example["prompt"] + example["rejected"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
# Handle DeepSeek-R1 <tool_call> token, see the above comment for details
|
||||
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_rejected)))
|
||||
rejected = prompt_rejected[len(prompt) :]
|
||||
if "completion" in example:
|
||||
prompt_completion = tokenizer.apply_chat_template(
|
||||
example["prompt"] + example["completion"], tools=tools, tokenize=False, **template_kwargs
|
||||
example["prompt"] + example["completion"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
# Handle DeepSeek-R1 <tool_call> token, see the above comment for details
|
||||
prompt = "".join(x for x, _ in takewhile(lambda x: x[0] == x[1], zip(prompt, prompt_completion)))
|
||||
completion = prompt_completion[len(prompt) :]
|
||||
else: # implicit prompt case
|
||||
if "chosen" in example:
|
||||
chosen = tokenizer.apply_chat_template(example["chosen"], tools=tools, tokenize=False, **template_kwargs)
|
||||
chosen = tokenizer.apply_chat_template(
|
||||
example["chosen"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
if "rejected" in example:
|
||||
rejected = tokenizer.apply_chat_template(
|
||||
example["rejected"], tools=tools, tokenize=False, **template_kwargs
|
||||
example["rejected"],
|
||||
tools=tools,
|
||||
tokenize=False,
|
||||
**example.get("chat_template_kwargs", {}),
|
||||
**template_kwargs,
|
||||
)
|
||||
|
||||
# Extract the completion by removing the prompt part from the prompt-completion string
|
||||
@ -239,7 +268,9 @@ def maybe_apply_chat_template(
|
||||
- Unpaired preference dataset: `"prompt"`, `"completion"`, and `"label"`.
|
||||
|
||||
For keys `"messages"`, `"prompt"`, `"chosen"`, `"rejected"`, and `"completion"`, the values are lists of
|
||||
messages, where each message is a dictionary with keys `"role"` and `"content"`.
|
||||
messages, where each message is a dictionary with keys `"role"` and `"content"`. Additionally, the example
|
||||
may contain a `"chat_template_kwargs"` key, which is a dictionary of additional keyword arguments to pass
|
||||
to the chat template renderer.
|
||||
tokenizer (`PreTrainedTokenizerBase`):
|
||||
Tokenizer to apply the chat template with.
|
||||
tools (`list[Union[dict, Callable]]`, *optional*):
|
||||
|
@ -17,12 +17,10 @@ from typing import TYPE_CHECKING
|
||||
from ..import_utils import _LazyModule
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"best_of_n_sampler": ["BestOfNSampler"],
|
||||
}
|
||||
_import_structure = {}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .best_of_n_sampler import BestOfNSampler
|
||||
pass
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
@ -1,132 +0,0 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, set_seed
|
||||
|
||||
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper
|
||||
|
||||
|
||||
class BestOfNSampler:
|
||||
"""
|
||||
Sampler for best-of-n generation.
|
||||
|
||||
Args:
|
||||
model ([`PreTrainedModelWrapper`]):
|
||||
The pretrained model to use for generation.
|
||||
tokenizer ([`~transformers.PreTrainedTokenizer`] or [`~transformers.PreTrainedTokenizerFast`]):
|
||||
Tokenizer associated with the pretrained model.
|
||||
queries_to_scores (`Callable[[list[str]], list[float]]`):
|
||||
Callable that takes a list of generated texts and returns the associated reward scores.
|
||||
length_sampler (`Any`):
|
||||
Sampler used to sample the length of the generated text.
|
||||
sample_size (`int`, *optional*, defaults to `4`):
|
||||
Number of samples to generate for each query.
|
||||
seed (`int`, *optional*):
|
||||
Random seed used to control generation.
|
||||
n_candidates (`int`, *optional*, defaults to `1`):
|
||||
Number of candidates to return for each query.
|
||||
generation_config ([`~transformers.GenerationConfig`], *optional*):
|
||||
Generation config passed to the underlying model's `generate` method. See
|
||||
[`~transformers.GenerationConfig`] for more details.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: PreTrainedModelWrapper,
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
queries_to_scores: Callable[[list[str]], list[float]],
|
||||
length_sampler: Any,
|
||||
sample_size: int = 4,
|
||||
seed: Optional[int] = None,
|
||||
n_candidates: int = 1,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
) -> None:
|
||||
if seed is not None:
|
||||
set_seed(seed)
|
||||
|
||||
if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
|
||||
raise ValueError(
|
||||
f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}"
|
||||
)
|
||||
if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
|
||||
raise ValueError(
|
||||
f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
|
||||
)
|
||||
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.queries_to_scores = queries_to_scores
|
||||
self.length_sampler = length_sampler
|
||||
self.gen_config = generation_config
|
||||
self.sample_size = sample_size
|
||||
self.n_candidates = n_candidates
|
||||
|
||||
def generate(
|
||||
self,
|
||||
tokenized_query: Union[list[int], torch.Tensor, list[torch.Tensor], list[list[int]]],
|
||||
skip_special_tokens: bool = True,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
**generation_kwargs,
|
||||
) -> list[list[str]]:
|
||||
"""
|
||||
Generate the best of n samples for input queries.
|
||||
|
||||
Args:
|
||||
tokenized_query (`list[int]` or `torch.Tensor` or `list[torch.Tensor]` or `list[list[int]]`):
|
||||
Either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries
|
||||
(a list of tensors or a list of lists of integers).
|
||||
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
||||
Whether to remove the special tokens from the output.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device on which the model will be loaded.
|
||||
**generation_kwargs:
|
||||
Additional keyword arguments passed along to the underlying model's `generate` method. This is used to
|
||||
override generation config.
|
||||
|
||||
Returns:
|
||||
`list[list[str]]`: A list of lists of generated texts.
|
||||
"""
|
||||
queries = None
|
||||
|
||||
if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1:
|
||||
queries = tokenized_query.unsqueeze(0)
|
||||
elif isinstance(tokenized_query, list):
|
||||
element_type = type(tokenized_query[0])
|
||||
if element_type is int:
|
||||
queries = torch.tensor(tokenized_query).unsqueeze(0)
|
||||
elif element_type is torch.Tensor:
|
||||
queries = [tensor.reshape((1, -1)) for tensor in tokenized_query]
|
||||
else:
|
||||
queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query]
|
||||
|
||||
result = []
|
||||
|
||||
for query in queries:
|
||||
queries = query.repeat((self.sample_size, 1))
|
||||
output = self.model.generate(
|
||||
queries.to(device),
|
||||
max_new_tokens=self.length_sampler(),
|
||||
generation_config=self.gen_config,
|
||||
**generation_kwargs,
|
||||
).squeeze()
|
||||
output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens)
|
||||
scores = torch.tensor(self.queries_to_scores(output))
|
||||
output = [output[i] for i in scores.topk(self.n_candidates).indices]
|
||||
result.append(output)
|
||||
|
||||
return result
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Callable, Literal, Optional
|
||||
|
||||
import datasets
|
||||
@ -41,7 +42,17 @@ def conversations_formatting_function(
|
||||
r"""
|
||||
return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the
|
||||
tokenizer apply chat template to the dataset along with the schema of the list of functions in the tools list.
|
||||
|
||||
.. deprecated:: 0.24.0
|
||||
`conversations_formatting_function` is deprecated and will be removed in version 0.27.
|
||||
Please use `tokenizer.apply_chat_template()` directly instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`conversations_formatting_function` is deprecated and will be removed in TRL 0.27. "
|
||||
"Please use `tokenizer.apply_chat_template()` directly instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def format_dataset(examples):
|
||||
if isinstance(examples[messages_field][0], list):
|
||||
@ -61,7 +72,17 @@ def instructions_formatting_function(tokenizer: AutoTokenizer):
|
||||
r"""
|
||||
return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the
|
||||
tokenizer apply chat template to the dataset
|
||||
|
||||
.. deprecated:: 0.24.0
|
||||
`instructions_formatting_function` is deprecated and will be removed in version 0.27.
|
||||
Please use `tokenizer.apply_chat_template()` directly instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`instructions_formatting_function` is deprecated and will be removed in TRL 0.27. "
|
||||
"Please use `tokenizer.apply_chat_template()` directly instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
def format_dataset(examples):
|
||||
if isinstance(examples["prompt"], list):
|
||||
@ -99,7 +120,18 @@ def get_formatting_func_from_dataset(
|
||||
|
||||
Returns:
|
||||
Callable: Formatting function if the dataset format is supported else None
|
||||
|
||||
.. deprecated:: 0.24.0
|
||||
`get_formatting_func_from_dataset` is deprecated and will be removed in version 0.27.
|
||||
Please use `tokenizer.apply_chat_template()` directly instead.
|
||||
"""
|
||||
warnings.warn(
|
||||
"`get_formatting_func_from_dataset` is deprecated and will be removed in TRL 0.27. "
|
||||
"Please use `tokenizer.apply_chat_template()` directly instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if isinstance(dataset, Dataset):
|
||||
if "messages" in dataset.features:
|
||||
if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
|
||||
|
@ -15,7 +15,7 @@
|
||||
import contextlib
|
||||
import functools
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator
|
||||
|
||||
from transformers import Trainer
|
||||
from transformers.integrations import is_mlflow_available, is_wandb_available
|
||||
@ -68,12 +68,12 @@ def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None
|
||||
mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
|
||||
|
||||
|
||||
def profiling_decorator(func: callable) -> callable:
|
||||
def profiling_decorator(func: Callable) -> Callable:
|
||||
"""
|
||||
Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].
|
||||
|
||||
Args:
|
||||
func (`callable`):
|
||||
func (`Callable`):
|
||||
Function to be profiled.
|
||||
|
||||
Example:
|
||||
|
@ -182,6 +182,7 @@ class VLLMClient:
|
||||
top_k: int = -1,
|
||||
min_p: float = 0.0,
|
||||
max_tokens: int = 16,
|
||||
truncate_prompt_tokens: Optional[int] = None,
|
||||
guided_decoding_regex: Optional[str] = None,
|
||||
generation_kwargs: Optional[dict] = None,
|
||||
) -> list[list[int]]:
|
||||
@ -207,6 +208,10 @@ class VLLMClient:
|
||||
Minimum probability for sampling.
|
||||
max_tokens (`int`, *optional*, defaults to `16`):
|
||||
Maximum number of tokens to generate for each prompt.
|
||||
truncate_prompt_tokens (`int`, *optional*):
|
||||
If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use
|
||||
only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is
|
||||
disabled.
|
||||
guided_decoding_regex (`str`, *optional*):
|
||||
Regular expression to guide the decoding process.
|
||||
generation_kwargs (`dict`, *optional*):
|
||||
@ -246,6 +251,7 @@ class VLLMClient:
|
||||
"top_k": top_k,
|
||||
"min_p": min_p,
|
||||
"max_tokens": max_tokens,
|
||||
"truncate_prompt_tokens": truncate_prompt_tokens,
|
||||
"guided_decoding_regex": guided_decoding_regex,
|
||||
"generation_kwargs": generation_kwargs or {},
|
||||
},
|
||||
|
@ -219,7 +219,6 @@ class OffloadActivations(saved_tensors_hooks):
|
||||
verify_sufficient_virtual_memory()
|
||||
|
||||
self.is_first_backward_call = False
|
||||
self.is_first_forward_call = True
|
||||
|
||||
if unpack_tensor_id not in self.tracker:
|
||||
raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")
|
||||
@ -231,6 +230,9 @@ class OffloadActivations(saved_tensors_hooks):
|
||||
|
||||
# clear tensor from tracking
|
||||
del self.tracker[unpack_tensor_id]
|
||||
# Only set is_first_forward_call to True when all tensors have been unpacked
|
||||
if len(self.tracker) == 0:
|
||||
self.is_first_forward_call = True
|
||||
return maybe_accelerator_tensor
|
||||
|
||||
def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
|
||||
@ -254,7 +256,6 @@ class OffloadActivations(saved_tensors_hooks):
|
||||
verify_sufficient_virtual_memory()
|
||||
|
||||
self.is_first_backward_call = False
|
||||
self.is_first_forward_call = True
|
||||
|
||||
if unpack_tensor_id not in self.tracker:
|
||||
raise ValueError(f"untracked tensor with id {unpack_tensor_id}")
|
||||
@ -359,6 +360,9 @@ class OffloadActivations(saved_tensors_hooks):
|
||||
|
||||
# clear tensor from tracking
|
||||
del self.tracker[unpack_tensor_id]
|
||||
# Only set is_first_forward_call to True when all tensors have been unpacked
|
||||
if len(self.tracker) == 0:
|
||||
self.is_first_forward_call = True
|
||||
return maybe_accelerator_tensor
|
||||
|
||||
unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
|
||||
|
@ -495,6 +495,7 @@ def main(script_args: ScriptArguments):
|
||||
top_k: int = -1
|
||||
min_p: float = 0.0
|
||||
max_tokens: int = 16
|
||||
truncate_prompt_tokens: Optional[int] = None
|
||||
guided_decoding_regex: Optional[str] = None
|
||||
generation_kwargs: dict = field(default_factory=dict)
|
||||
|
||||
@ -525,6 +526,9 @@ def main(script_args: ScriptArguments):
|
||||
- `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling.
|
||||
- `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each
|
||||
completion.
|
||||
- `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported
|
||||
by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left
|
||||
truncation). If set to `None`, truncation is disabled.
|
||||
- `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the
|
||||
model will only generate tokens that match this regex pattern.
|
||||
- `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM
|
||||
@ -575,6 +579,7 @@ def main(script_args: ScriptArguments):
|
||||
"top_k": request.top_k,
|
||||
"min_p": request.min_p,
|
||||
"max_tokens": request.max_tokens,
|
||||
"truncate_prompt_tokens": request.truncate_prompt_tokens,
|
||||
"guided_decoding": guided_decoding,
|
||||
"logprobs": 0,
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ import inspect
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from operator import itemgetter
|
||||
@ -360,6 +361,13 @@ class BCOTrainer(BaseTrainer):
|
||||
embedding_func: Optional[Callable] = None,
|
||||
embedding_tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if embedding_func is not None and not (is_sklearn_available() and is_joblib_available()):
|
||||
raise ImportError(
|
||||
"BCOTrainer with UDM requires the scikit-learn and joblib libraries. Please install it with `pip install scikit-learn joblib`."
|
||||
|
@ -14,6 +14,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
@ -583,7 +584,7 @@ class WeaveCallback(TrainerCallback):
|
||||
self,
|
||||
trainer: Trainer,
|
||||
project_name: Optional[str] = None,
|
||||
scorers: Optional[dict[str, callable]] = None,
|
||||
scorers: Optional[dict[str, Callable]] = None,
|
||||
generation_config: Optional[GenerationConfig] = None,
|
||||
num_prompts: Optional[int] = None,
|
||||
dataset_name: str = "eval_dataset",
|
||||
|
@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
@ -142,6 +144,13 @@ class CPOTrainer(BaseTrainer):
|
||||
peft_config: Optional[dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if args.model_init_kwargs is None:
|
||||
model_init_kwargs = {}
|
||||
elif not isinstance(model, str):
|
||||
|
@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
import warnings
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -38,11 +40,7 @@ from ..models import prepare_deepspeed
|
||||
from ..models.utils import unwrap_model_for_generation
|
||||
from .gkd_config import GKDConfig
|
||||
from .sft_trainer import SFTTrainer
|
||||
from .utils import (
|
||||
DataCollatorForChatML,
|
||||
disable_dropout_in_model,
|
||||
empty_cache,
|
||||
)
|
||||
from .utils import DataCollatorForChatML, disable_dropout_in_model, empty_cache
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
@ -127,6 +125,13 @@ class GKDTrainer(SFTTrainer):
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
formatting_func: Optional[Callable] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
# Ensure Trainer does not drop non-signature columns used by the collator (e.g., "prompts")
|
||||
args.remove_unused_columns = False
|
||||
# Respect a user-provided data_collator; otherwise, provide a ChatML collator that
|
||||
|
@ -14,7 +14,6 @@
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
import textwrap
|
||||
from collections import defaultdict, deque
|
||||
from contextlib import nullcontext
|
||||
@ -71,7 +70,6 @@ from .utils import (
|
||||
shuffle_sequence_dict,
|
||||
split_pixel_values_by_grid,
|
||||
split_tensor_dict,
|
||||
truncate_with_protected_tokens,
|
||||
unsplit_pixel_values_by_grid,
|
||||
)
|
||||
|
||||
@ -275,7 +273,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
|
||||
# Processing class
|
||||
if processing_class is None:
|
||||
processing_class = AutoProcessor.from_pretrained(model.config._name_or_path)
|
||||
processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left")
|
||||
|
||||
# Handle pad token for processors or tokenizers
|
||||
if isinstance(processing_class, ProcessorMixin):
|
||||
@ -291,10 +289,6 @@ class GRPOTrainer(BaseTrainer):
|
||||
self.pad_token = tokenizer.pad_token
|
||||
self.pad_token_id = tokenizer.pad_token_id
|
||||
self.eos_token_id = tokenizer.eos_token_id
|
||||
self.image_token = getattr(processing_class, "image_token", None)
|
||||
self.image_token_id = getattr(processing_class, "image_token_id", None)
|
||||
self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None)
|
||||
self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None)
|
||||
|
||||
# Reward functions
|
||||
if not isinstance(reward_funcs, list):
|
||||
@ -1092,58 +1086,12 @@ class GRPOTrainer(BaseTrainer):
|
||||
maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
|
||||
]
|
||||
|
||||
prompt_inputs = self.processing_class(
|
||||
text=prompts_text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
add_special_tokens=False,
|
||||
**kwargs,
|
||||
)
|
||||
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
||||
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
|
||||
|
||||
if self.max_prompt_length is not None:
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
|
||||
|
||||
# If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens.
|
||||
# Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special
|
||||
# tokens are needed for generation.
|
||||
protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id]
|
||||
protected = [token for token in protected if token is not None]
|
||||
prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids]
|
||||
|
||||
prompts_text = self.processing_class.batch_decode(
|
||||
prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
# The chat template sometimes inserts a single image token into the prompt text. However, when this text is
|
||||
# later tokenized, the single image token string is expanded into multiple image token IDs, depending on the
|
||||
# image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We
|
||||
# collapse them back into a single token string to match the original chat template in case it originally
|
||||
# applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images
|
||||
# (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only
|
||||
# the vision_start_token_id (e.g. <start_of_image>).
|
||||
if self.image_token is not None:
|
||||
escaped_img_token = re.escape(self.image_token)
|
||||
# Search for the image token in the chat template
|
||||
if re.search(escaped_img_token, self.processing_class.chat_template):
|
||||
prompts_text = [
|
||||
re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text
|
||||
]
|
||||
else:
|
||||
# If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id
|
||||
if self.vision_end_token_id is not None:
|
||||
escaped_eoi_token = re.escape(
|
||||
self.processing_class.tokenizer.decode([self.vision_end_token_id])
|
||||
)
|
||||
prompts_text = [
|
||||
re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text
|
||||
]
|
||||
else:
|
||||
# If vision_end_token_id is None, just remove the image tokens
|
||||
prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text]
|
||||
if images is not None:
|
||||
prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
|
||||
prompt_inputs = super()._prepare_inputs(prompt_inputs)
|
||||
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
|
||||
else:
|
||||
forward_kwargs = {}
|
||||
|
||||
# Generate completions using either vLLM or regular generation
|
||||
if self.use_vllm:
|
||||
@ -1185,6 +1133,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
top_k=-1 if self.top_k is None else self.top_k,
|
||||
min_p=0.0 if self.min_p is None else self.min_p,
|
||||
max_tokens=self.max_completion_length,
|
||||
truncate_prompt_tokens=self.max_prompt_length,
|
||||
guided_decoding_regex=self.guided_decoding_regex,
|
||||
generation_kwargs=self.args.generation_kwargs,
|
||||
)
|
||||
@ -1223,6 +1172,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
"top_k": -1 if self.top_k is None else self.top_k,
|
||||
"min_p": 0.0 if self.min_p is None else self.min_p,
|
||||
"max_tokens": self.max_completion_length,
|
||||
"truncate_prompt_tokens": self.max_prompt_length,
|
||||
"guided_decoding": guided_decoding,
|
||||
"logprobs": 0, # only return the logprob of the generated token
|
||||
}
|
||||
@ -1319,7 +1269,17 @@ class GRPOTrainer(BaseTrainer):
|
||||
|
||||
else:
|
||||
# Regular generation path
|
||||
prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
|
||||
generate_inputs = self.processing_class(
|
||||
text=prompts_text,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
padding_side="left",
|
||||
max_length=self.max_prompt_length,
|
||||
truncation=True,
|
||||
add_special_tokens=False,
|
||||
**kwargs,
|
||||
)
|
||||
generate_inputs = super()._prepare_inputs(generate_inputs)
|
||||
|
||||
with (
|
||||
profiling_context(self, "transformers.generate"),
|
||||
@ -1330,15 +1290,11 @@ class GRPOTrainer(BaseTrainer):
|
||||
FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
|
||||
):
|
||||
prompt_completion_ids = unwrapped_model.generate(
|
||||
input_ids=prompt_ids,
|
||||
attention_mask=prompt_mask,
|
||||
**forward_kwargs,
|
||||
generation_config=self.generation_config,
|
||||
disable_compile=True,
|
||||
**generate_inputs, generation_config=self.generation_config, disable_compile=True
|
||||
)
|
||||
# Compute prompt length and extract completion ids
|
||||
prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"]
|
||||
prompt_length = prompt_ids.size(1)
|
||||
prompt_ids = prompt_completion_ids[:, :prompt_length]
|
||||
completion_ids = prompt_completion_ids[:, prompt_length:]
|
||||
|
||||
# Mask everything after the first EOS token
|
||||
|
@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from operator import itemgetter
|
||||
@ -353,6 +355,13 @@ class KTOTrainer(BaseTrainer):
|
||||
model_adapter_name: Optional[str] = None,
|
||||
ref_adapter_name: Optional[str] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if type(args) is TrainingArguments:
|
||||
raise ValueError("Please use `KTOConfig` instead TrainingArguments.")
|
||||
|
||||
|
@ -412,3 +412,10 @@ class OnlineDPOConfig(TrainingArguments):
|
||||
|
||||
if hasattr(self.beta, "__len__") and len(self.beta) == 1:
|
||||
self.beta = self.beta[0]
|
||||
|
||||
if self.max_new_tokens >= self.max_length:
|
||||
warnings.warn(
|
||||
f"The configuration has `max_new_tokens` ({self.max_new_tokens}) >= `max_length` ({self.max_length}). "
|
||||
"This will cause prompts to be truncated or completely removed in the forward pass. "
|
||||
"To preserve prompts, ensure e.g. `max_length > max_new_tokens + 512`. ",
|
||||
)
|
||||
|
@ -57,8 +57,13 @@ from ..data_utils import apply_chat_template, is_conversational, maybe_apply_cha
|
||||
from ..extras.profiling import profiling_context
|
||||
from ..extras.vllm_client import VLLMClient
|
||||
from ..import_utils import is_vllm_available
|
||||
from ..models import create_reference_model, prepare_peft_model
|
||||
from ..models.utils import unwrap_model_for_generation
|
||||
from ..models import (
|
||||
create_reference_model,
|
||||
prepare_deepspeed,
|
||||
prepare_fsdp,
|
||||
prepare_peft_model,
|
||||
unwrap_model_for_generation,
|
||||
)
|
||||
from .base_trainer import BaseTrainer
|
||||
from .judges import BasePairwiseJudge
|
||||
from .online_dpo_config import OnlineDPOConfig
|
||||
@ -69,7 +74,6 @@ from .utils import (
|
||||
empty_cache,
|
||||
ensure_master_addr_port,
|
||||
pad,
|
||||
prepare_deepspeed,
|
||||
truncate_right,
|
||||
)
|
||||
|
||||
@ -206,6 +210,13 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
reward_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
||||
reward_processing_class: Optional[PreTrainedTokenizerBase] = None,
|
||||
) -> None:
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if ref_model is model:
|
||||
raise ValueError(
|
||||
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
||||
@ -581,24 +592,20 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
|
||||
self.generation_config = GenerationConfig(**generation_kwargs)
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
if self.ref_model is not None:
|
||||
self.ref_model = prepare_deepspeed(
|
||||
self.ref_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
||||
)
|
||||
# Prepare reward function models for DeepSpeed
|
||||
if self.reward_funcs is not None:
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
if self.ref_model is not None:
|
||||
if self.is_deepspeed_enabled:
|
||||
self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
|
||||
elif self.is_fsdp_enabled:
|
||||
self.ref_model = prepare_fsdp(self.ref_model, self.accelerator)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
if self.reward_funcs is not None:
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
if self.is_deepspeed_enabled:
|
||||
self.reward_funcs[i] = prepare_deepspeed(reward_func, self.accelerator)
|
||||
else:
|
||||
if self.ref_model is not None:
|
||||
self.ref_model = self.ref_model.to(self.accelerator.device)
|
||||
# Prepare reward function models for FSDP/regular training
|
||||
if self.reward_funcs is not None:
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
# Set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
|
||||
else:
|
||||
# set device placement to True to make `prepare_model` move `reward_func` to device when using fsdp
|
||||
self.reward_funcs[i] = self.accelerator.prepare_model(
|
||||
reward_func, evaluation_mode=True, device_placement=True
|
||||
)
|
||||
@ -826,8 +833,10 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
|
||||
def _generate_vllm_colocate(self, prompts, images=None):
|
||||
"""Generate completions using vLLM colocate mode"""
|
||||
# Update model weights if needed
|
||||
self._move_model_to_vllm()
|
||||
# Update model weights if needed - only after gradient accumulation completes
|
||||
if self.state.global_step != self._last_loaded_step:
|
||||
self._move_model_to_vllm()
|
||||
self._last_loaded_step = self.state.global_step
|
||||
|
||||
# Apply chat template if conversational
|
||||
if is_conversational({"prompt": prompts[0]}):
|
||||
@ -1227,10 +1236,12 @@ class OnlineDPOTrainer(BaseTrainer):
|
||||
# Get the logprobs of the completions from the model
|
||||
output = model(prompt_completion_ids, **model_kwargs)
|
||||
|
||||
# There is 1 offset, because the model predict the next token
|
||||
# There is 1 offset, because the model predicts the next token
|
||||
prompt_len = prompt_ids.size(1)
|
||||
start_idx = prompt_len - 1 if prompt_len > 0 else 0
|
||||
logits = output.logits[:, start_idx:-1]
|
||||
# Only slice off the last logit when we have a prompt, otherwise we need all logits
|
||||
end_idx = -1 if prompt_len > 0 else None
|
||||
logits = output.logits[:, start_idx:end_idx]
|
||||
|
||||
# Take the completion tokens logprob
|
||||
logprobs = torch.take_along_dim(logits.log_softmax(dim=-1), completion_ids.unsqueeze(-1), dim=2).squeeze(-1)
|
||||
|
@ -13,8 +13,10 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
@ -144,6 +146,13 @@ class ORPOTrainer(BaseTrainer):
|
||||
peft_config: Optional[dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if args.model_init_kwargs is None:
|
||||
model_init_kwargs = {}
|
||||
elif not isinstance(model, str):
|
||||
|
@ -17,6 +17,7 @@ import math
|
||||
import os
|
||||
import textwrap
|
||||
import time
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from pathlib import Path
|
||||
@ -159,6 +160,13 @@ class PPOTrainer(BaseTrainer):
|
||||
callbacks: Optional[list[TrainerCallback]] = None,
|
||||
peft_config: Optional["PeftConfig"] = None,
|
||||
) -> None:
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if ref_model is model:
|
||||
raise ValueError(
|
||||
"`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
|
||||
|
@ -12,7 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import textwrap
|
||||
import warnings
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Union
|
||||
@ -117,6 +119,13 @@ class PRMTrainer(BaseTrainer):
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[dict] = None,
|
||||
):
|
||||
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
|
||||
warnings.warn(
|
||||
"This trainer will soon be moved to trl.experimental and is a candidate for removal. If you rely on "
|
||||
"it and want it to remain, please share your comments here: "
|
||||
"https://github.com/huggingface/trl/issues/4223. Silence this warning by setting environment variable "
|
||||
"TRL_EXPERIMENTAL_SILENCE=1."
|
||||
)
|
||||
if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
|
||||
model = prepare_peft_model(model, peft_config, args)
|
||||
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user