Compare commits


31 Commits

Author SHA1 Message Date
5d8ad5d538 Release: v0.7.0 2023-08-30 09:46:35 +00:00
9d09b3e107 TextEnvironments (#424)
* WIP skeleton

* minimal working poc

* cleanup

* rename variables

* quick typo fix

* add v1 masking (#429)

* add v1 masking

* working v1

* adapt from suggestion

* avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.`

* fix masking

- mask the responses from API call only

* quality

* address comments

* Update trl/environment/base.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* adapt a bit

* wip on tokenization/masking in textenv

* small fixes

* update viz

* add example

* print debug text and pass masks

* style

* format and move tensor to device

* update example

* update example

* This seems to work

* fix masking

* fix rich output to console

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
Co-authored-by: leandro <leandro.vonwerra@spoud.io>

* Add masking (#461)

* add v1 masking

* working v1

* adapt from suggestion

* avoid warning `Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.`

* fix masking

- mask the responses from API call only

* quality

* address comments

* Update trl/environment/base.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* adapt a bit

* wip on tokenization/masking in textenv

* small fixes

* update viz

* add example

* print debug text and pass masks

* style

* format and move tensor to device

* update example

* update example

* This seems to work

* fix masking

* fix rich output to console

* fix batched generation

* improve stopping criteria

* improve error handling in tool call

---------

Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Costa Huang <costa.huang@outlook.com>

* fix unknown tool

* fix rewards and increase bs

* remove unused script

* ugly WIP fix

* do not return modified obj for in-place operations

* do not return modified obj for in-place operations

* clean up stopping criterium

* push updates

* push update

* format, add docs

* rename file

* add kwargs to reward fn

* simplify example

* simplify example

* bug fix

* add a trivia example

* pre-commit

* max tool response length

* fix regex for multi-line

* refactor tool exceptions

* fix exceptions in tool

* add docs

* fix style

* make rich optional

* add docstrings

* add  tests

* add TextEnv tests (WIP)

* update triviaqa code

* update docs

* refactor text env

* update tests (WIP)

* add end2end test

* update docs

* upload tool demo

* refactor

* customizable system prompt

* add text env docs

* update index and toc

* fix `TextHistory` show methods

* add max length

* fix style

* fix typo

* refactor to kwargs in init and tasks to queries

* kwargs for reward docs

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/tool_demo.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/learning_tools.mdx

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update docs/source/text_environments.md

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update examples/triviaqa.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* move to tool folder

* remove assets

* remove tool demo

* move rich import test to import utils

* add copyright

* fixes for masks in ppo trainer

* add text env api docs

* make precommit + add ppo test with mask

* move examples and add python

* fix style

* update triviaqa example

* add more docs

* update docs

* Update docs/source/learning_tools.mdx

* Apply suggestions from code review

* precommit

---------

Co-authored-by: Costa Huang <costa.huang@outlook.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: younesbelkada <younesbelkada@gmail.com>
Co-authored-by: leandro von werra <leandro@hf.co>
2023-08-30 11:44:06 +02:00
336d63eb80 [Docs] fix example README.md (#705) 2023-08-30 11:27:50 +02:00
7fc970983c [DPO] fix DPO ref_model=None (#703)
* fix by @tannonk

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* add import

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-08-29 12:57:10 +02:00
d3bbee3ab8 set dev version (#685) 2023-08-24 11:04:07 +02:00
eb5465df7e Release: v0.6.0 (#684) 2023-08-24 10:18:46 +02:00
1c272240ac Simplify immutable TrainingArgs fix using dataclasses.replace (#682) 2023-08-24 09:50:48 +02:00
Wei
b095245830 fix PeftConfig loading from a remote repo. (#649)
* fix PeftConfig loading from a remote repo.

* failed to catch hf_hub_download() EntryNotFoundError.

At least in huggingface-hub 0.10.1, the error for "not found" is:
huggingface_hub.utils._errors.EntryNotFoundError: 404 Client Error

* pass precommit checks.

* replace some bare excepts with specific codes

* catch LocalEntryNotFoundError additionally.
2023-08-24 09:50:20 +02:00
c115453fba Update sft_llama2.py (#678)
Add argument num_workers. Fixed error on line 103 if streaming is set to False
2023-08-23 16:56:31 +02:00
16f214c58d fix unmutable TrainingArguments issue (#676) 2023-08-23 10:54:59 +02:00
e9a437992e propagating eval_batch_size to TrainingArguments (#675)
Co-authored-by: Rahul Jha <rahuljha@netflix.com>
2023-08-23 10:52:25 +02:00
c837fbe5b9 Fix DPO blogpost thumbnail (#673) 2023-08-22 11:53:21 +02:00
01c4a35928 Denoising Diffusion Policy Optimization (#508)
* Broken first pre-draft

* Change structure to leverage user-definition of pipeline
 - reward function, pipeline and scheduler will be left to the user to define
 - pipeline and scheduler contract interfaces is what the framework will define
 - none of this actually works

* Incremental progress: trying to get the set-up running e2e

* Incremental progress: successfully running code

* Incremental progress: running setup
Next steps: fix accelerate gradient acc assertion error when we set value > 1

* Formatting and code standards

* Incremental prog: break down code a bit
- new config flag to notify code of async reward fetching
- break off image handling code and throw it on to user to define how to handle it
- more code restructuring

* Incremental progress:
1. More code sectioning off into own methods (more for readability than anything else)

* Incremental progress:
1. clear up contracts
2. type the reward function and prompt function

* Code shuffling and expansion of tracker, accelerator config args to beyond wandb

* More small additions
Add tensorboard logging function
Remove wandb logging function for now
Consolidate the data that gets thrown to the logging function
Add README

* Formatting

* Formatting

* Remove print statement
Make tensorboard tracking the sole tracking for the training example

* 1. start of testing
2. more refactoring
3. start of docstrings
4. parameter rename

* Basic Tests
Formatting

* Docs according to the norm

* Docs, credits and rename file

* docs and corrections

* Put example config to respectable state

* Add recent run params

* Correct the name of the library

* Move requirements to EXTRAS

* - Add license banners
- Guard import of DDPO functions with if_diffusers_available
- doc strings for output types

* Add snippet to pull weights from huggingface + banner

* Test if passes on CI/CD

* Minor refactor

* Test dummy unet

* Possible fix for randomly disappearing attribute

* Shuffling arrangement in hopes of meeting memory requirements

* Proper Names

* Appease windows memory allocator issues for the cpu device

* Remove print statements

* Update docs/source/ddpo_trainer.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update docs/source/ddpo_trainer.mdx

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Add docstrings and correct url

* Spelling and grammar

* Add more documentation and commandline parsing for example script

* Markdown syntax correction

* Revert accidentally committed file and put the correct one

* More docs

* Remove subclassing and add docs for leftover subclassing

* Put back subclassing

* Reward metadata and more docs

* Remove save_load_save flag

* Grammar

* Update trl/trainer/ddpo_trainer.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update tests/test_ddpo_trainer.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update setup.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Update examples/scripts/stable_diffusion_tuning.py

Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>

* Edits to the readme for DDPO

* Renamed modelling_sd_base to modeling_sd_base

* Insert try and catch for bitsandbytes import

* Change to smaller model

* Correct tolerance for floating point comparison

* Remove dummy unet and move to check is isfinite

* 1. Expand interface to ensure other Stable Diffusion pipelines could be covered
2. remove extra identification

* 1. Remove most of the asserts except for one and add value error
2. Remove default run name

* Remove progress bar

* Docs

* Put back progress bar

* 1. Revert progress bar deletion completely
2. grammar
3. relocate line

* Experiment

* Remove experiment parts and format properly

* Change formatting and edit info in docs

* Grammar

* Refactor out most of nitty gritty of loading/saving from trainer to example model
Readme addition

* Docs additions

* 1. Proper formatting for the test file
2. incorporation of pull from hub; if it fails, try local
3. doc strings for interface
4. highlight in the trainer that this is only ready for SD pipelines

* Resources for before and after

* Attempt at embedding images

* Post testing example script

* Consistent naming and document edits in light of new args

* Remove resources and add CDN links in html in doc file

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Leandro von Werra <lvwerra@users.noreply.github.com>
2023-08-21 19:24:52 +02:00
1aca98fbcf add check of arguments (#660) 2023-08-21 12:02:07 +02:00
029f961b7c Handle potentially long sequences with DataCollatorForCompletionOnlyLM (#644)
* avoid RuntimeError on long sequences

* add unittests and format

* remove dependency on external repo

* bug fix in DataCollatorForCompletionOnlyLM
2023-08-18 10:30:25 +02:00
8ec912ffa6 Add more args to SFT example (#642)
* add more args

* fix style issues
2023-08-18 10:15:43 +02:00
f360c37466 Allow for ref_model=None in DPOTrainer (#640)
* Update dpo_trainer.py

Make ref_model optional.

* add tests for ref_model=None

* better handling for ref_model=None

* Update dpo_trainer.py

Correct docstring

* move instantiation of self.ref_model closer to model

* use .disable_adapters instead of .get_base_model

* handle ref_model=None in get_batch_samples

* fix failing test in dpo_trainer due to disable_dropout_in_model

* Update trl/trainer/dpo_trainer.py

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-08-18 10:02:16 +02:00
217313014b Update README.md (#657)
* Update README.md

fix reward modeling example

* Update README.md

more concise fix
2023-08-17 22:00:58 +02:00
b946e875b1 Resolve various typos throughout the docs (#654)
* Resolve various typos throughout the docs

I found the first few manually, and then found the rest via codespell

* HuggingFace -> Hugging Face

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

---------

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
2023-08-17 12:27:54 +02:00
6dd50b45d8 Add checks on backward batch size (#651)
* Add checks on backward batch size

* add test case

* update test case

* Update citation
2023-08-17 10:35:44 +02:00
98120d6aeb Disable dropout in DPO Training (#639)
* disable dropout in dpo

* quick fix docs

* precommit

* add disable_dropout_in_model to DPOTrainer

* disable_dropout -> disable_dropout_in_model

* .

* .
2023-08-14 14:40:45 +02:00
3b2c820db6 Add score scaling/normalization/clipping (#560)
* Add reward/score scaling/normalization/clipping

* Run pre-commit to fix styles and remove some dupe code

* Make sure score module and pretrained_model have the same dtype

* Add multi_adapter_rl_v2.py

* Add log_with

* Add more verbose help message for use_score_norm

* Fix score clipping for float16

* Minor fix
2023-08-10 10:30:56 +02:00
25fd6f2313 Move repo (#628)
* update actions

* update references
2023-08-09 17:48:25 +02:00
3f1477cdc0 Improve docs (#612)
* WIP

* improve inference docs

* improve training faq

* update toctree

* fix toctree

* fix improve blog

* improve blog

* fix customization

* reword faq a bit

* reword inference a bit

* add references back

* integrate feedback from code review

* fix link in html
2023-08-08 11:45:16 +02:00
2cff1e4385 Allow already tokenized sequences for response_template in DataCollatorForCompletionOnlyLM (#622)
* Allow tokenized ids in DataCollatorForCompletionOnlyLM. Add test and docs

* Formatting

* Documentation

* Remove unused code from test

---------

Co-authored-by: Ivan Sanchez <ivan.sanchez@zyte.com>
2023-08-08 11:33:12 +02:00
d7d7902938 use log_with argument (#620) 2023-08-08 10:13:22 +02:00
77b0cc1707 [DPO] stack-llama-2 training scripts (#611)
* initial stack-llama-2 scripts

* removed unused function

* add accelerate

* link to stack-llama-2 code

* running the model

* pre-commit fixes

* use the merge_peft script

* Add section on logged metrics
2023-08-07 14:36:16 +02:00
17f22c1c20 Add docs explaining logged metrics (#616) 2023-08-04 12:50:39 -04:00
e448bb69f0 [Modeling] Add token support for hf_hub_download (#604)
* add token support for hf_hub_download

* allow to pass it to from_pretrained
2023-08-03 12:49:31 +02:00
9aa4e3ce2b set dev version (#608) 2023-08-02 10:43:27 +02:00
ca8a508913 Release: 0.5.0 (#607) 2023-08-02 10:31:43 +02:00
66 changed files with 5153 additions and 205 deletions

View File

@ -13,7 +13,6 @@ jobs:
with:
commit_sha: ${{ github.sha }}
package: trl
repo_owner: lvwerra
version_tag_suffix: ""
secrets:
token: ${{ secrets.HUGGINGFACE_PUSH }}

View File

@ -14,5 +14,4 @@ jobs:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: trl
repo_owner: lvwerra
version_tag_suffix: ""

View File

@ -7,7 +7,7 @@ on:
jobs:
close_stale_issues:
name: Close Stale Issues
if: github.repository == 'lvwerra/trl'
if: github.repository == 'huggingface/trl'
runs-on: ubuntu-latest
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@ -17,7 +17,7 @@ authors:
family-names: Thrush
- given-names: Nathan
family-names: Lambert
repository-code: 'https://github.com/lvwerra/trl'
repository-code: 'https://github.com/huggingface/trl'
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
keywords:
- rlhf

View File

@ -6,14 +6,14 @@
> Full stack transformer language models with reinforcement learning.
<p align="center">
<a href="https://github.com/lvwerra/trl/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/github/license/lvwerra/trl.svg?color=blue">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue">
</a>
<a href="https://huggingface.co/docs/trl/index">
<img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_message=online">
</a>
<a href="https://github.com/lvwerra/trl/releases">
<img alt="GitHub release" src="https://img.shields.io/github/release/lvwerra/trl.svg">
<a href="https://github.com/huggingface/trl/releases">
<img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg">
</a>
</p>
@ -24,7 +24,7 @@
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
</div>
`trl` is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point most of decoder architectures and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.
`trl` is a full stack library where we provide a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point most of decoder architectures and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.
**Highlights:**
@ -32,7 +32,7 @@
- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- [Examples](https://github.com/lvwerra/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
- [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
## How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:
@ -60,7 +60,7 @@ pip install trl
### From source
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
```bash
git clone https://github.com/lvwerra/trl.git
git clone https://github.com/huggingface/trl.git
cd trl/
pip install .
```
@ -106,7 +106,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer
# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
...
@ -162,16 +162,6 @@ reward = [torch.tensor(1.0)]
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```
### Advanced example: IMDB sentiment
For a detailed example check out the example python script `examples/scripts/sentiment_tuning.py`, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/table_imdb_preview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> A few review continuations before and after optimisation. </p>
</div>
Have a look at more examples inside [`examples/`](https://github.com/lvwerra/trl/tree/main/examples) folder.
## References
### Proximal Policy Optimisation
@ -184,11 +174,11 @@ The language models utilize the `transformers` library by 🤗 Hugging Face.
```bibtex
@misc{vonwerra2022trl,
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert},
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang},
title = {TRL: Transformer Reinforcement Learning},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/lvwerra/trl}}
howpublished = {\url{https://github.com/huggingface/trl}}
}
```

View File

@ -5,8 +5,12 @@
title: Quickstart
- local: installation
title: Installation
- local: how_to_train
title: PPO Training FAQ
- local: use_model
title: Use Trained Models
- local: customization
title: Customize your Training
title: Customize the Training
- local: logging
title: Understanding Logs
title: Get started
@ -23,6 +27,10 @@
title: Best of N Sampling
- local: dpo_trainer
title: DPO Trainer
- local: ddpo_trainer
title: Denoising Diffusion Policy Optimization
- local: text_environments
title: Text Environments
title: API
- sections:
- local: sentiment_tuning
@ -33,6 +41,8 @@
title: Detoxifying a Language Model
- local: using_llama_models
title: Training StackLlama
- local: learning_tools
title: Learning to Use Tools
- local: multi_adapter_rl
title: Multi Adapter RLHF
title: Examples

View File

@ -16,7 +16,7 @@ Then make sure you have selected multi-gpu / multi-node setup. You can then run
accelerate launch your_script.py
```
Refer to the [examples page](https://github.com/lvwerra/trl/tree/main/examples) for more details
Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details
## Use different optimizers
@ -180,18 +180,21 @@ else:
sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```
## Use torch distributed
The torch.distributed package provides PyTorch-native methods to distribute a network over several machines (mostly useful if there are several GPU nodes). It copies the model on each GPU, runs the forward and backward passes on each, and then applies the mean of the gradients of all GPUs to each one. If running torch 1.XX, you can call `torch.distributed.launch`, like
`python -m torch.distributed.launch --nproc_per_node=1 reward_summarization.py --bf16`
For torch 2.+ `torch.distributed.launch` is deprecated and one needs to run:
`torchrun --nproc_per_node=1 reward_summarization.py --bf16`
or
`python -m torch.distributed.run --nproc_per_node=1 reward_summarization.py --bf16`
Note that using `python -m torch.distributed.launch --nproc_per_node=1 reward_summarization.py --bf16` with torch 2.0 ends in
ValueError: Some specified arguments are not used by the HfArgumentParser: ['--local-rank=0']
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 194889) of binary: /home/ubuntu/miniconda3/envs/trl/bin/python
## Use score scaling/normalization/clipping
As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://arxiv.org/abs/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`:
```python
from trl import PPOConfig

ppo_config = {
    "use_score_scaling": True,
    "use_score_norm": True,
    "score_clip": 0.5,
}
config = PPOConfig(**ppo_config)
```
To run `sentiment_tuning.py`, you can use the following command:
```
python examples/scripts/sentiment_tuning.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
```

View File

@ -0,0 +1,144 @@
# Denoising Diffusion Policy Optimization
## The why
| Before | After DDPO finetuning |
| --- | --- |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |
## Getting started with Stable Diffusion finetuning with reinforcement learning
The machinery for finetuning Stable Diffusion models with reinforcement learning makes heavy use of Hugging Face's `diffusers`
library. Getting started therefore requires a bit of familiarity with the `diffusers` library concepts, mainly two of them: pipelines and schedulers.
Right out of the box, the `diffusers` library provides neither a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning; some adjustments need to be made.
This library provides a pipeline interface that must be implemented in order to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide.
The point of the interface is to fuse the pipeline and the scheduler into one object, which keeps all the constraints in one place. The interface was designed in the hope of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at the time of writing. The scheduler step is also a method of this pipeline interface; this may seem redundant given that the raw scheduler is accessible via the interface, but it is the only way to constrain the scheduler step output to an output type befitting the algorithm at hand (DDPO).
For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py)
Note that the default implementation has a LoRA implementation path and a non-LoRA based implementation path. LoRA is enabled by default and can be turned off by passing the corresponding flag. LoRA-based training is faster, and the LoRA-associated model hyperparameters responsible for model convergence aren't as finicky as in non-LoRA based training.
In addition, there is the expectation of providing a reward function and a prompt function. The reward function is used to evaluate the generated images and the prompt function is used to generate the prompts that are used to generate the images (a rough sketch of both is given below).
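As a rough sketch of what these two callables might look like (the exact signatures and return types expected by `DDPOTrainer` are an assumption here and should be checked against the training example):
```python
import torch

def prompt_fn():
    # return a prompt plus any metadata you want passed through to the reward function
    return "a photo of a squirrel riding a bike", {}

def reward_fn(images, prompts, prompt_metadata):
    # score each generated image; a dummy brightness-based score, purely for illustration
    rewards = torch.stack([image.float().mean() for image in images])
    return rewards, {}
```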
## Getting started with `examples/scripts/stable_diffusion_tuning.py`
The `stable_diffusion_tuning.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).
**Note:** one A100 GPU is recommended to get this running. Anything below an A100 will not be able to run this example script, and even if it does with relatively smaller parameters, the results will most likely be poor.
Almost every configuration parameter has a default. There is only one command-line flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to the Hugging Face hub. The following bash command gets things running:
```bash
python stable_diffusion_tuning.py --hf_user_access_token <token>
```
Again, the script uses a small subset of parameters to configure the trainer, and all of these are configurable via the command line.
Note that because the trainer uses `accelerate` as a core component, some of the parameters are `accelerate`'s own.
The command-line flags associated with the example script's parameters are listed below.
|parameter|description|default|
| ---- | ---- | ---- |
|`--hf_hub_aesthetic_model_id`|The HuggingFace model hub id of the aesthetic scorer model|`"trl-lib/ddpo-aesthetic-predictor"`|
|`--hf_hub_aesthetic_model_filename`|The filename of the aesthetic scorer model |`"aesthetic-model.pth"`|
|`--pretrained_model`|The string id of the pretrained Stable Diffusion model|`"runwayml/stable-diffusion-v1-5"`|
|`--pretrained_revision`|The revision of the pretrained Stable Diffusion model|`"main"`|
|`--num_epochs`|The number of epochs to train for|`200`|
|`--train_batch_size`|The batch size to use for training|`3`|
|`--sample_batch_size`|The batch size to use for sampling|`6`|
|`--gradient_accumulation_steps`|The number of accelerator based gradient accumulation steps to use|`1`|
|`--sample_num_steps`| The number of steps to sample for|`50`|
|`--sample_num_batches_per_epoch`|The number of batches to sample per epoch|`4`|
|`--log_with`|The logger to use. Either `wandb` or `tensorboard`|`wandb`|
|`--per_prompt_stat_tracking`|Whether to track stats per prompt. If false, advantages will be calculated using the mean and std of the entire batch as opposed to tracking per prompt|`True`|
|`--per_prompt_stat_tracking_buffer_size`|The size of the buffer to use for tracking stats per prompt|`32`|
|`--tracker_project_name`|The name of the project for use on the tracking platform (wandb/tensorboard/etc) |`"stable_diffusion_training"`|
| `--logging_dir`|The directory to use for logging|`"logs"`|
| `--project_dir`|The directory to use for saving the model|`"save"`|
| `--automatic_checkpoint_naming`|Whether to automatically name model checkpoints|`True`|
| `--total_limit`| Number of checkpoints to keep before overwriting old ones|`5`|
| `--hf_hub_model_id`|The HuggingFace model hub id to use for saving the model|`"ddpo-finetuned-sd-model"`|
| `--hf_user_access_token`| The HuggingFace user access token|`None`|
The following are things to keep in mind in general while configuring the trainer, beyond the use case of the example script (the code checks these for you as well); a small validation sketch follows the list:
- The configurable sample batch size should be greater than or equal to the configurable training batch size
- The configurable sample batch size must be divisible by the configurable train batch size
- The configurable sample batch size must be divisible by both the configurable gradient accumulation steps and the configurable accelerator processes count
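A minimal sketch of these checks in plain Python (the variable names below are illustrative, not the actual `DDPOConfig` field names):
```python
def check_ddpo_batch_sizes(sample_batch_size, train_batch_size,
                           gradient_accumulation_steps, num_processes):
    # the sample batch size should be at least as large as the train batch size
    assert sample_batch_size >= train_batch_size
    # and must split evenly into train batches
    assert sample_batch_size % train_batch_size == 0
    # it must also split evenly across gradient accumulation steps and accelerator processes
    assert sample_batch_size % gradient_accumulation_steps == 0
    assert sample_batch_size % num_processes == 0

check_ddpo_batch_sizes(sample_batch_size=6, train_batch_size=3,
                       gradient_accumulation_steps=1, num_processes=1)
```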
## Setting up the image logging hook function
Expect the function to be given a list of lists of the form
```python
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
```
and `image`, `prompt`, `prompt_metadata`, `rewards`, `reward_metadata` are batched.
The last list in the list of lists represents the last sample batch, which is the one you will most likely want to log.
While you are free to log however you want, the use of `wandb` or `tensorboard` is recommended.
### Key terms
- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
- `prompt` : The prompt is the text that is used to generate the image
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model comprises a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup, where questions and ground-truth answers (linked to the generated image) are expected alongside the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
- `image` : The image generated by the Stable Diffusion model
Example code for logging sampled images with `wandb` is given below.
```python
import numpy as np
from PIL import Image

# for logging these images to wandb
def image_outputs_hook(image_data, global_step, accelerate_logger):
# For the sake of this example, we only care about the last batch
# hence we extract the last element of the list
result = {}
images, prompts, _, rewards, _ = image_data[-1]
for i, image in enumerate(images):
pil = Image.fromarray(
(image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
)
pil = pil.resize((256, 256))
result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
accelerate_logger.log_images(
result,
step=global_step,
)
```
### Using the finetuned model
Assuming you're done with all the epochs and have pushed your model to the hub, you can use the finetuned model as follows:
```python
import torch
from trl import DefaultDDPOStableDiffusionPipeline
pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)
prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)
for prompt, image in zip(prompts,results.images):
image.save(f"{prompt}.png")
```
## Credits
This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models
with Reinforcement Learning by Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine](https://arxiv.org/abs/2305.13301).

View File

@ -4,12 +4,12 @@ Language models (LMs) are known to sometimes generate toxic outputs. In this exa
Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/lvwerra/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
| File | Description | Colab link |
|---|---| --- |
| [`gpt-j-6b-toxicity.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
| [`evaluate-toxicity.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms)| An interactive Space that you can use to compare the original model with its detoxified version!| x |
## Context
@ -174,7 +174,7 @@ Below are few generation examples of `gpt-j-6b-detox` model:
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-toxicity-examples.png">
</div>
The evaluation script can be found [here](https://github.com/lvwerra/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
### Discussions

View File

@ -1,6 +1,6 @@
# DPO Trainer
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/dpo.py`](https://github.com/lvwerra/trl/blob/main/examples/dpo.py).
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/dpo.py).
The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
@ -77,6 +77,15 @@ dpo_trainer.train()
Note that the `beta` is the temperature parameter for the DPO loss, typically something in the range of `0.1` to `0.5`. We ignore the reference model as `beta` -> 0.
## Logging
While training and evaluating we record the following reward metrics (a minimal sketch of how they relate to the model log-probabilities follows below):
* `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses, scaled by beta
* `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses, scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
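A minimal sketch of how these metrics follow from per-sequence summed log-probabilities, assuming tensors of shape `(batch,)` are already computed (this mirrors the metric definitions above rather than the exact trainer code):
```python
import torch

def dpo_reward_metrics(policy_chosen_logps, policy_rejected_logps,
                       ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # implicit DPO rewards: beta-scaled log-ratio between policy and reference model
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    return {
        "rewards/chosen": chosen_rewards.mean(),
        "rewards/rejected": rejected_rewards.mean(),
        "rewards/accuracies": (chosen_rewards > rejected_rewards).float().mean(),
        "rewards/margins": (chosen_rewards - rejected_rewards).mean(),
    }
```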
## DPOTrainer
[[autodoc]] DPOTrainer

View File

@ -0,0 +1,66 @@
# Training FAQ
## What Metrics Should I Look at?
When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
To address this, we recommend focusing on two key metrics first:
**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
However, there are more metrics that can be useful for debugging; check out the [logging section](logging).
## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kl-example.png">
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://arxiv.org/pdf/1909.08593.pdf">https://arxiv.org/pdf/1909.08593.pdf</a>. </p>
</div>
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
## What Is the Concern with Negative KL Divergence?
If you generate text by purely sampling from the model distribution, things work fine in general. But when you use the `generate` method there are a few caveats, because it does not always purely sample depending on the settings, which can cause the KL-divergence to go negative. Essentially, when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in several cases:
- **top-k sampling**: the model can smooth out the probability distribution, causing the top-k tokens to have a smaller probability than those of the reference model while still being selected
- **min_length**: this ignores the EOS token until `min_length` is reached, thus the model can assign a very high log prob to the EOS token and very low prob to all others until `min_length` is reached
- **batched generation**: finished sequences in a batch are padded until all generations are finished. The model can learn to assign very low probabilities to the padding tokens unless they are properly masked or removed.
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed as `R = r - beta * KL` (sketched below), so if the model can learn how to drive the KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than to actually learn the reward function. In addition, the KL can become arbitrarily small and thus the actual reward can be very small compared to it.
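As a rough sketch of the penalty described above (assuming per-token log-probabilities of the sampled response under the active and reference models are already available; adding the scalar score at the final token mirrors common PPO implementations but is an assumption here):
```python
import torch

def kl_penalized_rewards(score, active_logprobs, ref_logprobs, beta=0.1):
    # per-token KL estimate: log p_active(token) - log p_ref(token)
    kl = active_logprobs - ref_logprobs      # shape: (response_len,)
    rewards = -beta * kl                     # KL penalty on every token
    rewards[-1] = rewards[-1] + score        # scalar reward added at the final token
    return rewards
```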
So how should you generate text for PPO training? Let's have a look!
## How to generate text for training?
In order to avoid the KL issues described above, we recommend using the following settings:
```python
generation_kwargs = {
"min_length": -1, # don't ignore the EOS token (see above)
"top_k": 0.0, # no top-k sampling
"top_p": 1.0, # no nucleus sampling
"do_sample": True, # yes, we want to sample
"pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
"max_new_tokens": 32, # specify how many tokens you want to generate at most
}
```
With these settings we usually don't encounter any issues. You can also experiment with other settings, but if you encounter issues with negative KL-divergence, try going back to these and see if they persist.
## How can you debug your own use-case?
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut off generations too soon. These things are very hard to see in the metrics but very obvious if you look at the generations.
- **Inspect the reward model**: If your reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query, which the model can't affect, so you might need to normalize this (e.g. reward of query+response minus reward of the query, as sketched below).
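A small sketch of the normalization suggested in the last point, assuming a `reward_model` callable that scores a piece of text (an illustrative helper, not a TRL API):
```python
def normalized_reward(reward_model, query, response):
    # credit the model only for what its response adds on top of the query
    return reward_model(query + response) - reward_model(query)
```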
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!

View File

@ -13,19 +13,45 @@ The library is integrated with 🤗 [transformers](https://github.com/huggingfac
Check the appropriate sections of the documentation depending on your needs:
API documentation:
## API documentation
- [Model Classes](models): *A brief overview of what each public model class does.*
- [`SFTTrainer`](sft_trainer): *Supervised fine-tuning of your model made easy with `SFTTrainer`*
- [`RewardTrainer`](reward_trainer): *Easily train your reward model using `RewardTrainer`.*
- [`PPOTrainer`](trainer): *Further fine-tune the supervised fine-tuned model using the PPO algorithm*
- [Best-of-N Samppling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
- [`DPOTrainer`](trainer): *Direct Preference Optimization training using `DPOTrainer`.*
- [`TextEnvironment`](text_environment): *Text environment to train your model using tools with RL.*
Examples:
## Examples
- [Sentiment Tuning](sentiment_tuning): *Fine-tune your model to generate positive movie content*
- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT*
- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF*
- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset*
- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`*
- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training*
## Blog posts
<div class="mt-10">
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail">
<p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
<img src="https://github.com/huggingface/blog/blob/main/assets/133_trl_peft/thumbnail.png?raw=true" alt="thumbnail">
<p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
<img src="https://github.com/huggingface/blog/blob/main/assets/138_stackllama/thumbnail.png?raw=true" alt="thumbnail">
<p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
<img src="https://github.com/huggingface/blog/blob/main/assets/157_dpo_trl/dpo_thumbnail.png?raw=true" alt="thumbnail">
<p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
</a>
</div>
</div>

View File

@ -12,7 +12,7 @@ pip install trl
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
```bash
git clone https://github.com/lvwerra/trl.git
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .
```

View File

@ -0,0 +1,229 @@
# Learning Tools (Experimental 🧪)
Using Large Language Models (LLMs) with tools has been a popular topic recently with awesome works such as [ToolFormer](https://arxiv.org/abs/2302.04761) and [ToolBench](https://arxiv.org/pdf/2305.16504.pdf). In TRL, we provide a simple example of how to teach an LLM to use tools with reinforcement learning.
Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):
| File | Description |
|---|---|
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train an LLM to use a calculator with reinforcement learning. |
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train an LLM to use a wiki tool to answer questions. |
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train an LLM to use a Python interpreter to solve math puzzles. |
<Tip warning={true}>
Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
</Tip>
## Learning to Use a Calculator
The rough idea is as follows:
1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parses a text calculation like `"14 + 34"` and returns the calculated number:
```python
from transformers import AutoTokenizer, load_tool
tool = load_tool("ybelkada/simple-calculator")
tool_fn = lambda text: str(round(float(tool(text)), 2)) # rounding to 2 decimal places
```
2. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later (see the sketch after this walkthrough).
3. Create a prompt showing how to use the tools
```python
# system prompt
prompt = """\
What is 13.1-3?
<request><SimpleCalculatorTool>13.1-3<call>10.1<response>
Result=10.1<submit>
What is 4*3?
<request><SimpleCalculatorTool>4*3<call>12<response>
Result=12<submit>
What is 12.1+1?
<request><SimpleCalculatorTool>12.1+1<call>13.1<response>
Result=13.1<submit>
What is 12.1-20?
<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>
Result=-7.9<submit>"""
```
4. Create a `trl.TextEnvironment` with the model
```python
env = TextEnvironment(
model,
tokenizer,
{"SimpleCalculatorTool": tool_fn},
reward_fn,
prompt,
generation_kwargs=generation_kwargs,
)
```
5. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools.png)
6. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`. A sketch of how the rewards could be computed from the submitted answers is given below.
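A hedged sketch of the reward override mentioned in step 2: decode the responses from `env.run`, compare them against the known answers, and feed the resulting rewards to the PPO step (the `exact_match_reward` helper below is illustrative, not necessarily the exact one used in `calculator.py`):
```python
import torch

def exact_match_reward(response_texts, answers):
    # reward 1.0 if the expected answer appears in the generated text, else 0.0
    return [torch.tensor(1.0 if str(answer) in text else 0.0)
            for text, answer in zip(response_texts, answers)]

# queries, responses, masks, rewards, histories = env.run(tasks)
# response_texts = [tokenizer.decode(response) for response in responses]
# rewards = exact_match_reward(response_texts, answers)
# train_stats = ppo_trainer.step(queries, responses, rewards, masks)
```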
## Experiment results
We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.
```
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
--command "python examples/calculator_few_shots_env.py" \
--num-seeds 10 \
--start-seed 1 \
--workers 10 \
--slurm-gpus-per-task 1 \
--slurm-ntasks 1 \
--slurm-total-cpus 8 \
--slurm-template-path benchmark/trl.slurm_template
```
We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
```
python -m openrlbenchmark.rlops_multi_metrics \
--filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
'wandb?tag=calculator_final&cl=calculator_mask' \
--env-ids trl \
--check-empty-runs \
--pc.ncols 2 \
--pc.ncols-legend 1 \
--output-filename static/0compare \
--scan-history
```
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/learning_tools_chart.png)
As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near-perfect proficiency in the calculator task.
## (Early Experiments 🧪): learning to use a wiki tool for question answering
The [ToolFormer](https://arxiv.org/abs/2302.04761) paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.
<Tip warning={true}>
**Note that many settings are different so the results are not directly comparable.**
</Tip>
### Building a search index
Since [ToolFormer](https://arxiv.org/abs/2302.04761) did not open-source its code, we needed to first replicate the search index. Their paper mentions that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT).
Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.
```python
from pyserini.search.lucene import LuceneSearcher
import json
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
def search(query):
hits = searcher.search(query, k=1)
hit = hits[0]
contents = json.loads(hit.raw)['contents']
return contents
print(search("tennis racket"))
```
```
Racket (sports equipment)
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.
The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
```
We then deployed this snippet as a Hugging Face Space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the Space as a `transformers.Tool` later.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pyserini.png)
### Experiment settings
We use the following settings:
* use the `bigcode/starcoderbase` model as the base model
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only use the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
* test if the response contains the answer string; if so, give a reward of 1, otherwise give a reward of 0 (see the sketch after the prompt below).
* note that this is a simplified evaluation criterion. In [ToolFormer](https://arxiv.org/abs/2302.04761), the authors check whether the first 20 words of the response contain the correct answer.
* use the following prompt that demonstrates the usage of the wiki tool.
```python
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
```
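As a rough illustration of the reward rule described in the settings above, the check can be as simple as a substring match. The following is a minimal sketch; `response_text` and `answer` are placeholder names rather than the exact variables used in the training script:
```python
def exact_answer_reward(response_text: str, answer: str) -> float:
    # reward 1 if the answer string appears anywhere in the generated response, else 0
    return 1.0 if answer.lower() in response_text.lower() else 0.0
```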
### Result and Discussion
Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly trend upward, but one of the experiments crashed.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/triviaqa_learning_curves.png)
The Wandb report is available [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.
Note that the accuracy of the trained model is on the low end, which could be due to the following reasons:
* **incorrect searches:** When given the question `"What is Bruce Willis' real first name?"`, if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." But a correct search should return "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles.[1][2]"
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/real_first_name.png)
* **unnecessarily long response**: The wiki tool by default sometimes outputs very long sequences. E.g., when the wiki tool searches for "Brown Act":
* Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
* [ToolFormer](https://arxiv.org/abs/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.
![](https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/brown_act.png)
## (Early Experiments 🧪): solving math puzzles with python interpreter
In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:
```python
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
    money_initial = 23
    bagels = 5
    bagel_cost = 3
    money_spent = bagels * bagel_cost
    money_left = money_initial - money_spent
    result = money_left
    return result
print(solution())
<call>8<response>
Result = 8 <submit>
Q: """
```
Training results TBD.

View File

@ -14,16 +14,62 @@ If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir
## PPO Logging
Here's a brief explanation for the logged metrics provided in the data:
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is used to specifically monitor the reward model.
1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias `ppo/std_scores`, which is used to specifically monitor the reward model.
1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
1. `ppo/mean_non_score_reward`: The **KL penalty**, calculated by `objective/kl * objective/kl_coef`, which is added to the reward during optimization to prevent the new policy from deviating too far from the old policy.
1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
Training stats:
1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html (a short sketch of both estimators follows this list)
1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
1. `ppo/val/vpred`: The predicted values from the value function.
1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
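For reference, the two approximate-KL statistics above can be written out directly. This is a minimal sketch that mirrors the formulas given for `ppo/policy/approxkl` and `ppo/policy/policykl`; `logprobs`, `old_logprobs`, and `mask` are assumed to be per-token tensors, and `masked_mean` is an illustrative helper rather than the trainer's internal one:
```python
import torch

def masked_mean(values, mask):
    # average only over positions where mask == 1 (model-generated tokens)
    return (values * mask).sum() / mask.sum()

def approx_kl_stats(logprobs, old_logprobs, mask):
    approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)  # "k2" estimator
    policykl = masked_mean(old_logprobs - logprobs, mask)               # "k1" estimator
    return approxkl, policykl
```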
Stats on queries, responses, and logprobs:
1. `tokens/queries_len_mean`: The average length of the queries tokens.
1. `tokens/queries_len_std`: The standard deviation of the length of the queries tokens.
1. `tokens/queries_dist`: The histogram distribution of the length of the queries tokens.
1. `tokens/responses_len_mean`: The average length of the responses tokens.
1. `tokens/responses_len_std`: The standard deviation of the length of the responses tokens.
1. `tokens/responses_dist`: The histogram distribution of the length of the responses tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
### Crucial values
During training, many values are logged, here are the most important ones:
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment".
2. `ppo/mean_scores`: The mean scores directly out of the reward model.
3. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
### Training stability parameters:
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
1. `ppo/loss/value`: The value function loss -- will spike / NaN when not going well.
2. `ppo/val/clipfrac`: The fraction of clipped values in the value function loss. This is often from 0.3 to 0.6.
3. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
1. `ppo/loss/value`: it will spike / NaN when not going well.
1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, the `ratio` is going to get clipped, resulting in high `clipfrac` and high `approxkl` as well.
1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.

View File

@ -3,13 +3,13 @@
The notebooks and scripts in these examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory-efficient manner. Most of the PEFT methods supported in the peft library are available, but note that some PEFT methods, such as prompt tuning, are not supported.
For more information on LoRA, see the [original paper](https://arxiv.org/abs/2106.09685).
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
| File | Task | Description | Colab link |
|---|---|---|---|
| [`stack_llama/rl_training.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
| [`stack_llama/reward_modeling.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
| [`stack_llama/supervised_finetuning.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
## Installation
Note: peft is in active development, so we install directly from their Github page.

View File

@ -11,10 +11,10 @@ You just need to install `peft` and optionally install `bitsandbytes` as well if
You need to address this approach in three stages that we summarize as follows:
1- Train a base model on the target domain (e.g. `imdb` dataset) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL.
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/lvwerra/trl/tree/main/examples/scripts/reward_trainer.py)
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_trainer.py)
3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL")
Make sure to use the same model (i.e. same architecure and same weights) for the stages 2 & 3.
Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3.
## Quickstart

View File

@ -2,7 +2,7 @@
TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
Check out a complete flexible example inside [`examples/scripts`](https://github.com/lvwerra/trl/tree/main/examples/scripts/reward_trainer.py) folder.
Check out a complete flexible example inside [`examples/scripts`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_trainer.py) folder.
## Expected dataset format
@ -23,7 +23,7 @@ The `j` and `k` suffixes are used to denote the two sentences in the paired data
## Using the `RewardTrainer`
After standardizing your dataset, you can use the `RewardTrainer` as a classic HugingFace Trainer.
After standardizing your dataset, you can use the `RewardTrainer` as a classic Hugging Face Trainer.
You should pass an `AutoModelForSequenceClassification` model to the `RewardTrainer`.
### Leveraging the `peft` library to train a reward model

View File

@ -2,15 +2,15 @@
The notebooks and scripts in these examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`).
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
| File | Description | Colab link |
|---|---| --- |
| [`gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) | Fine-tune GPT2 to generate positive movie reviews. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb)
| [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) | Fine-tune GPT2 to generate positive movie reviews. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb)
|
| [`gpt2-sentiment-control.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment-control.ipynb) | Fine-tune GPT2 to generate movie reviews with controlled sentiment. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb)
| [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/blob/main/examples/notebooks/gpt2-sentiment-control.ipynb) | Fine-tune GPT2 to generate movie reviews with controlled sentiment. | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb)
|
| [`gpt2-sentiment.py`](https://github.com/lvwerra/trl/blob/main/examples/ppo_trainer/sentiment_tuning.py) | Same as the notebook, but easier to use to use in multi-GPU setup with any architecture. | x |
| [`gpt2-sentiment.py`](https://github.com/huggingface/trl/blob/main/examples/ppo_trainer/sentiment_tuning.py) | Same as the notebook, but easier to use in a multi-GPU setup with any architecture. | x |
## Installation

View File

@ -2,7 +2,7 @@
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with a few lines of code on your dataset.
Check out a complete flexible example inside [`examples/scripts`](https://github.com/lvwerra/trl/tree/main/examples/scripts/sft_trainer.py) folder.
Check out a complete flexible example inside [`examples/scripts`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_trainer.py) folder.
## Quickstart
@ -46,7 +46,7 @@ trainer = SFTTrainer(
trainer.train()
```
The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/lvwerra/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example.
The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/huggingface/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example.
## Advanced usage
@ -111,6 +111,47 @@ trainer = SFTTrainer(
trainer.train()
```
#### Using token_ids directly for `response_template`
Some tokenizers like Llama 2 (`meta-llama/Llama-2-XXb-hf`) tokenize sequences differently depending on whether they have context or not. For example:
```python
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
def print_tokens_with_ids(txt):
    tokens = tokenizer.tokenize(txt, add_special_tokens=False)
    token_ids = tokenizer.encode(txt, add_special_tokens=False)
    print(list(zip(tokens, token_ids)))
prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]
response_template = "### Assistant:"
print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
```
In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently:
- Text (with context): `[2277, 29937, 4007, 22137, 29901]`
- `response_template` (without context): `[835, 4007, 22137, 29901]`
This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:
```
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
```
To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed, and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
```python
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```
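The collator can then be passed to the trainer as usual. A minimal sketch, assuming `model` and `dataset` are defined as in the earlier snippets and that the dataset stores the full prompts in a `"text"` column:
```python
trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    dataset_text_field="text",
    data_collator=data_collator,
)
trainer.train()
```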
### Format your input prompts
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
@ -142,7 +183,7 @@ trainer = SFTTrainer(
trainer.train()
```
To preperly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/lvwerra/trl/pull/444#issue-1760952763)
To properly format your input, make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example of how to use SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763)
### Packing dataset ([`ConstantLengthDataset`])
@ -185,7 +226,7 @@ You can also customize the [`ConstantLengthDataset`] much more by directly passi
### Control over the pretrained model
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analoguous to
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analogous to
```python
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)

View File

@ -0,0 +1,197 @@
# Text Environments
Text environments provide a learning ground for language agents. They allow a language model to use tools to accomplish a task, such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the model itself but can be trivial with the appropriate tools. A good example is arithmetic with large numbers, which becomes a simple copy-paste task once you have access to a calculator.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv.png">
</div>
Let's dive into how text environments work and start with tools!
## Tools
One of the core building blocks of text environments are the tools that the model can use to solve tasks. In general, tools can be any Python function that takes a string as input and returns a string. The `TextEnvironment` offers two options for tools: either use predefined tools from `transformers.Tool` or define your own function or class with a `__call__` method. Let's have a look at both!
### `transformers.Tool`
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared.
```Python
from transformers import load_tool
# simple calculator tool that runs +-/* operations
calc_tool = load_tool("ybelkada/simple-calculator")
# python interpreter that executes program and returns outputs
py_tool = load_tool("lvwerra/python-interpreter")
# wikipedia search index that returns best search match
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
```
These tools are either loaded from the hub or from a local folder. Using the tool is as simple as calling them with a text query:
```Python
calc_tool("1/2")
>>> "0.5"
```
Note that both input and return values are strings to enable easy usage with a language model.
### Custom Tools
The following is an example of a tool that adds two integers:
```Python
def add(text):
    int_1, int_2 = text.split("+")
    result = int(int_1) + int(int_2)
    return str(result)
print(add("1+1"))
>>> "2"
```
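Tools can also be classes that implement `__call__`, which is handy when the tool needs state such as a client or a cache. The class below is purely illustrative and not part of TRL:
```Python
class WebSearchTool:
    """Hypothetical search tool: takes a query string and returns the top result as a string."""

    def __init__(self, search_fn):
        self.search_fn = search_fn  # e.g. a function wrapping a search API
        self.cache = {}

    def __call__(self, query: str) -> str:
        # cache results so repeated queries don't hit the search backend again
        if query not in self.cache:
            self.cache[query] = str(self.search_fn(query))
        return self.cache[query]
```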
We looked at basic examples, such as a calculator, but the principle holds for more complex tools as well, such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
### Call syntax
In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
```python
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
```
There are a few special tokens involved, so let's decompose it: First, the model can signal that it wants to use a tool by emitting the `<request>` token. After that, we want to know the name of the tool to call, which is done by enclosing the tool name in `<>` brackets. Once we know which tool to call, the tool query follows in free-text form. The `<call>` token signifies the end of the query and stops the model generation. At this point, the model output is parsed and the query is sent to the tool. The environment appends the tool response to the string, followed by the `<response>` token to mark the end of the tool output.
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
```python
"<request><Calculator>1/2<call>0.5<response>"
```
Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
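To make the syntax concrete, here is a rough sketch of how such a segment could be split into tool name and query; this is only an illustration of the format, not TRL's actual parsing code:
```python
def parse_tool_call(text):
    """Illustrative parser for a '<request><TOOL_NAME>QUERY<call>' segment."""
    request = text.split("<request>", 1)[1]                 # "<Calculator>1/2<call>"
    tool_name = request.split(">", 1)[0].lstrip("<")        # "Calculator"
    query = request.split(">", 1)[1].split("<call>", 1)[0]  # "1/2"
    return tool_name, query

print(parse_tool_call("<request><Calculator>1/2<call>"))  # ('Calculator', '1/2')
```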
Now let's have a look at how we can create a new text environment!
## Create a `TextEnvironment`
```python
prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
"""
def reward_fn(result, answer):
    """Simplified reward function returning 1 if result matches answer and 0 otherwise."""
    result_parsed = result.split("=")[1].split("<")[0]
    return int(result_parsed == answer)

text_env = TextEnvironment(
    model=model,
    tokenizer=tokenizer,
    tools={"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
    reward_fn=reward_fn,
    prompt=prompt,
    max_turns=1,
    max_tool_response=100,
    generation_kwargs={"do_sample": True},
)
```
Let's decompose the settings:
| Argument | Description |
|:-------------------|:----------------|
| `model` | Language model to interact with the environment and generate requests. |
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
| `tools` | A `list` or `dict` of tools. If a list, the name of each tool is inferred from its class name; if a dict, the keys are used as the tool names.|
| `reward_fn` | A function that takes a string as input and returns a reward. It can have extra arguments that are passed to `.run()`, such as the ground truth.|
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
| `max_length` | The maximum number of tokens to allow in an episode. |
| `generation_kwargs`| Generation settings used by the language model. |
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
## Run an Episode
To run a set of queries through the text environment one can simply use the `run` method.
```python
queries = ["What is 1/2?"]
answers = ["0.5"]
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
```
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached, or the maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
There are five objects that are returned by `run`:
- `queries`: a list of the tokenized queries
- `responses`: all tokens that have been generated within the environment, including model and tool tokens
- `masks`: masks that indicate which tokens were generated by the model and which tokens were generated by the tools
- `rewards`: a list of rewards for each query/response
- `histories`: a list of `TextHistory` objects, which are useful objects containing all the above as well as the text equivalents
The masks are crucial for training, as we don't want to optimize tokens that the model has not generated, i.e. the tokens produced by the tools.
Next, we'll train a PPO step with the generated responses!
### Train
Training on episodes from the `TextEnvironment` is straightforward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
```python
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
```
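Putting `run` and `step` together, a full training loop could look roughly like the following sketch, where `batches` (an iterator yielding pairs of query and answer string lists) is a placeholder for your own data loading and `text_env`/`ppo_trainer` are assumed to be configured as above:
```python
for queries, answers in batches:  # placeholder iterator over batches of task strings
    query_tensors, response_tensors, masks, rewards, histories = text_env.run(queries, answers=answers)
    train_stats = ppo_trainer.step(query_tensors, response_tensors, rewards, masks)
    # optional logging, following the usual PPOTrainer pattern
    ppo_trainer.log_stats(train_stats, {"query": queries, "response": [h.text for h in histories]}, rewards)
```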
## `TextHistory`
The `TextHistory` object stores the interactions between the model and the text environment. It stores the tokens and text generated in each turn, their source (model or system), and the rewards. Let's go through the class attributes and methods.
### Attributes
The following table summarises the available attributes of the `TextHistory` class:
| Attribute | Description |
|:-------------------|:----------------|
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
| `tokens` | All tokens generated in text environment with both model and system generated tokens. |
| `token_spans` | Similar to `text_spans`, the `token_spans` indicate the boundaries of model and system generated tokens. |
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
| `completed` | Indicates if the interaction with the environment has completed. |
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
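For example, the span attributes can be used to walk through an interaction segment by segment. A minimal sketch, assuming `history` is a `TextHistory` returned by `text_env.run` and that `True` in `system_spans` marks system-generated segments:
```python
for (start, end), is_system in zip(history.text_spans, history.system_spans):
    source = "system" if is_system else "model"
    print(f"[{source}] {history.text[start:end]}")
```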
### Visualization
When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods).
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue, and in addition to the pure text output the reward is displayed as additional text in plum. Here is an example of `show_text`:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_text.png" width=600>
</div>
Sometimes there can be tricky tokenization-related issues that are hidden when showing the decoded text. Thus, `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_tokens.png" width=800>
</div>
Note that you can turn on the colour legend by passing `show_legend=True`.
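For instance, assuming `rich` is installed and `histories` comes from `text_env.run` as above (check the API documentation below for the exact signatures):
```python
history = histories[0]
history.show_text(show_legend=True)               # highlighted text view with the colour legend
history.show_tokens(tokenizer, show_legend=True)  # token-level view (tokenizer argument assumed)
```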
## API Documentation
[[autodoc]] TextEnvironment
[[autodoc]] TextHistory

View File

@ -24,6 +24,14 @@ We also support a `RewardTrainer` that can be used to train a reward model.
[[autodoc]] DPOTrainer
## DDPOConfig
[[autodoc]] DDPOConfig
## DDPOTrainer
[[autodoc]] DDPOTrainer
## set_seed
[[autodoc]] set_seed

docs/source/use_model.md Normal file
View File

@ -0,0 +1,58 @@
# Use model after training
Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference).
## Load and Generate
If you have fine-tuned a model fully, meaning without the use of PEFT, you can simply load it like any other language model in transformers. For example, the value head that was trained during PPO training is no longer needed, and if you load the model with the original transformer class it will be ignored:
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
device = "cpu" # or "cuda" if you have a GPU
model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device)
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))
```
Alternatively you can also use the pipeline:
```python
from transformers import pipeline
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
pipe = pipeline("text-generation", model=model_name_or_path)
print(pipe("This movie was really")[0]["generated_text"])
```
## Use Adapters with PEFT
```python
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
adapter_model_name = "path/to/my/adapter"
model = AutoModelForCausalLM.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(model, adapter_model_name)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
```
You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger:
```python
model = AutoModelForCausalLM.from_pretrained(base_model_name)
model = PeftModel.from_pretrained(model, adapter_model_name)
model = model.merge_and_unload()
model.save_pretrained("merged_adapters")
```
Once you have loaded the model and either merged the adapters or kept them separately on top, you can run generation as with a normal model, as outlined above.

View File

@ -157,4 +157,4 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
ppo_trainer.log_stats(stats, batch, rewards)
```
For the rest of the details adn evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).

View File

@ -29,7 +29,7 @@ pip install trl
pip install wandb
```
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks.
You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
You can also replace it with your favorite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
## Accelerate Config
For all the examples, you'll need to generate an `Accelerate` config with:
@ -41,10 +41,10 @@ accelerate config # will prompt you to define the training configuration
Then, it is encouraged to launch jobs with `accelerate launch`!
## Categories
The examples are currently split over the following categories:
The examples are currently split into the following categories:
**1: [ppo_trainer](https://github.com/lvwerra/trl/tree/main/examples/scripts/sentiment_tuning.py)**: Learn about different ways of using PPOTrainer
**2: [sft_trainer](https://github.com/lvwerra/trl/tree/main/examples/scripts/sft_trainer.py)**: Learn about how to leverage `SFTTrainer` for supervised fine-tuning your pretrained language models easily.
**3: [reward_modeling](https://github.com/lvwerra/trl/tree/main/examples/scripts/reward_trainer.py)**: Learn about how to use `RewardTrainer` to easily train your own reward model to use it for your RLHF pipeline.
**4: [research_projects](https://github.com/lvwerra/trl/tree/main/examples/research_projects)**: Check out that folder to check the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
**5: [notebooks](https://github.com/lvwerra/trl/tree/main/examples/notebooks)**: Check out that folder to check some applications of TRL features directly on a Jupyter notebook. This includes running sentiment tuning and sentiment control on a notebook, but also how to use "Best of N sampling" strategy using TRL.
1. **[ppo_trainer](https://github.com/huggingface/trl/tree/main/examples/scripts/sentiment_tuning.py)**: Learn about different ways of using PPOTrainer
1. **[sft_trainer](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_trainer.py)**: Learn about how to leverage `SFTTrainer` for supervised fine-tuning your pretrained language models easily.
1. **[reward_modeling](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_trainer.py)**: Learn about how to use `RewardTrainer` to easily train your own reward model to use it for your RLHF pipeline.
1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
1. **[notebooks](https://github.com/huggingface/trl/tree/main/examples/notebooks)**: Check out this folder for some applications of TRL features directly on a Jupyter notebook. This includes running sentiment tuning and sentiment control on a notebook and how to use the "Best of N sampling" strategy using TRL.

View File

@ -101,6 +101,7 @@ if __name__ == "__main__":
]
model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

View File

@ -2,6 +2,6 @@
This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications.
- [`best_of_n.ipynb`](https://github.com/lvwerra/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
- [`gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
- [`gpt2-control.ipynb`](https://github.com/lvwerra/trl/tree/main/examples/notebooks/gpt2-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control exampel on a jupyter notebook.
- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
- [`gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a Jupyter notebook.

View File

@ -847,7 +847,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
"version": "3.9.12"
},
"vscode": {
"interpreter": {

View File

@ -398,7 +398,7 @@
"metadata": {},
"source": [
"### Training progress\n",
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://app.wandb.ai/lvwerra/trl-showcase/runs/1jtvxb1m/).\n",
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://app.wandb.ai/huggingface/trl-showcase/runs/1jtvxb1m/).\n",
"\n",
"<div style=\"text-align: center\">\n",
"<img src='https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gpt2_tuning_progress.png' width='800'>\n",

View File

@ -1,6 +1,7 @@
# Research projects that uses TRL
# Research projects that use TRL
Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developpers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information!
Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information!
- [De-detoxifying language models](https://github.com/lvwerra/trl/tree/main/examples/research_projects/toxicity)
- [Stack-Llama](https://github.com/lvwerra/trl/tree/main/examples/research_projects/stack_llama)
- [De-detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity)
- [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama)
- [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2)

View File

@ -0,0 +1,51 @@
# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model
## Prerequisites
Install all the dependencies in the `requirements.txt`:
```
$ pip install -U -r requirements.txt
```
Since we will use `accelerate` for training, make sure to run:
```
$ accelerate config
```
## Training
There were two main steps to the DPO training process:
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
- `accelerate launch examples/stack_llama_2/scripts/sft_llama2.py --output_dir="sft"`
1. Run the DPO trainer using the model saved by the previous step:
- `accelerate launch examples/stack_llama_2/scripts/dpo_llama2.py --model_name_or_path="sft/final_checkpoint" --output_dir="dpo"`
## Merging the adaptors
To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL:
```
python trl/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2"
```
which will also push the model to your Hugging Face Hub account.
## Running the model
We can load the DPO-trained LoRA adaptors that were saved by the DPO training step via:
```py
import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    "dpo/final_checkpoint",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)
model.generate(...)
```

View File

@ -0,0 +1,223 @@
# 0. imports
import os
from dataclasses import dataclass, field
from typing import Dict, Optional
import torch
from datasets import Dataset, load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments
from trl import DPOTrainer
# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
The arguments for the DPO training script.
"""
# data parameters
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
# training parameters
model_name_or_path: Optional[str] = field(
default="../sft/results/final_checkpoint",
metadata={"help": "the location of the SFT model name or path"},
)
learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
gradient_accumulation_steps: Optional[int] = field(
default=4, metadata={"help": "the number of gradient accumulation steps"}
)
gradient_checkpointing: Optional[bool] = field(
default=True, metadata={"help": "whether to use gradient checkpointing"}
)
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
# instrumentation
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
report_to: Optional[str] = field(
default="wandb",
metadata={
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
},
)
# debug argument for distributed training
ignore_bias_buffers: Optional[bool] = field(
default=False,
metadata={
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
},
)
def get_stack_exchange_paired(
data_dir: str = "data/rl",
sanity_check: bool = False,
cache_dir: str = None,
num_proc=24,
) -> Dataset:
"""Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
The dataset is converted to a dictionary with the following structure:
{
'prompt': List[str],
'chosen': List[str],
'rejected': List[str],
}
Prompts are structured as follows:
"Question: " + <prompt> + "\n\nAnswer: "
"""
dataset = load_dataset(
"lvwerra/stack-exchange-paired",
split="train",
cache_dir=cache_dir,
data_dir=data_dir,
)
original_columns = dataset.column_names
if sanity_check:
dataset = dataset.select(range(min(len(dataset), 1000)))
def return_prompt_and_responses(samples) -> Dict[str, str]:
return {
"prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
"chosen": samples["response_j"],
"rejected": samples["response_k"],
}
return dataset.map(
return_prompt_and_responses,
batched=True,
num_proc=num_proc,
remove_columns=original_columns,
)
if __name__ == "__main__":
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
# 1. load a pretrained model
model = AutoPeftModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_4bit=True,
)
model.config.use_cache = False
if script_args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
]
model_ref = AutoPeftModelForCausalLM.from_pretrained(
script_args.model_name_or_path,
low_cpu_mem_usage=True,
torch_dtype=torch.float16,
load_in_4bit=True,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tokenizer.pad_token = tokenizer.eos_token
# 2. Load the Stack-exchange paired dataset
train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
train_dataset = train_dataset.filter(
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
)
# 3. Load evaluation dataset
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
eval_dataset = eval_dataset.filter(
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
)
# 4. initialize training arguments:
training_args = TrainingArguments(
per_device_train_batch_size=script_args.per_device_train_batch_size,
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
max_steps=script_args.max_steps,
logging_steps=script_args.logging_steps,
save_steps=script_args.save_steps,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
gradient_checkpointing=script_args.gradient_checkpointing,
learning_rate=script_args.learning_rate,
evaluation_strategy="steps",
eval_steps=script_args.eval_steps,
output_dir=script_args.output_dir,
report_to=script_args.report_to,
lr_scheduler_type=script_args.lr_scheduler_type,
warmup_steps=script_args.warmup_steps,
optim=script_args.optimizer_type,
bf16=True,
remove_unused_columns=False,
run_name="dpo_llama2",
)
peft_config = LoraConfig(
r=script_args.lora_r,
lora_alpha=script_args.lora_alpha,
lora_dropout=script_args.lora_dropout,
target_modules=[
"q_proj",
"v_proj",
"k_proj",
"out_proj",
"fc_in",
"fc_out",
"wte",
],
bias="none",
task_type="CAUSAL_LM",
)
# 5. initialize the DPO trainer
dpo_trainer = DPOTrainer(
model,
model_ref,
args=training_args,
beta=script_args.beta,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
tokenizer=tokenizer,
peft_config=peft_config,
max_prompt_length=script_args.max_prompt_length,
max_length=script_args.max_length,
)
# 6. train
dpo_trainer.train()
dpo_trainer.save_model(script_args.output_dir)
# 7. save
output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
dpo_trainer.model.save_pretrained(output_dir)

View File

@ -0,0 +1,7 @@
transformers
trl
peft
accelerate
datasets
bitsandbytes
wandb

View File

@ -0,0 +1,216 @@
# Fine-Tune Llama2-7b on SE paired dataset
import os
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import load_dataset
from peft import AutoPeftModelForCausalLM, LoraConfig
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
from trl import SFTTrainer
from trl.trainer import ConstantLengthDataset
@dataclass
class ScriptArguments:
model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
max_steps: Optional[int] = field(default=500, metadata={"help": "the maximum number of sgd steps"})
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
save_steps: Optional[int] = field(default=10, metadata={"help": "the saving frequency"})
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "the per device train batch size"})
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "the per device eval batch size"})
gradient_accumulation_steps: Optional[int] = field(default=2, metadata={"help": "the gradient accumulation steps"})
gradient_checkpointing: Optional[bool] = field(
default=True, metadata={"help": "whether to use gradient checkpointing"}
)
group_by_length: Optional[bool] = field(default=False, metadata={"help": "whether to group by length"})
packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
learning_rate: Optional[float] = field(default=1e-4, metadata={"help": "the learning rate"})
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
num_warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
if script_args.group_by_length and script_args.packing:
raise ValueError("Cannot use both packing and group by length")
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
"""
Estimate the average number of characters per token in the dataset.
"""
total_characters, total_tokens = 0, 0
for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
text = prepare_sample_text(example)
total_characters += len(text)
if tokenizer.is_fast:
total_tokens += len(tokenizer(text).tokens())
else:
total_tokens += len(tokenizer.tokenize(text))
return total_characters / total_tokens
def print_trainable_parameters(model):
"""
Prints the number of trainable parameters in the model.
"""
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
)
def prepare_sample_text(example):
"""Prepare the text from a sample of the dataset."""
text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
return text
def create_datasets(tokenizer, args):
dataset = load_dataset(
args.dataset_name,
data_dir=args.subset,
split=args.split,
use_auth_token=True,
num_proc=args.num_workers if not args.streaming else None,
streaming=args.streaming,
)
if args.streaming:
print("Loading the dataset in streaming mode")
valid_data = dataset.take(args.size_valid_set)
train_data = dataset.skip(args.size_valid_set)
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None)
else:
dataset = dataset.train_test_split(test_size=0.005, seed=None)
train_data = dataset["train"]
valid_data = dataset["test"]
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
chars_per_token = chars_token_ratio(train_data, tokenizer)
print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")
train_dataset = ConstantLengthDataset(
tokenizer,
train_data,
formatting_func=prepare_sample_text,
infinite=True,
seq_length=args.seq_length,
chars_per_token=chars_per_token,
)
valid_dataset = ConstantLengthDataset(
tokenizer,
valid_data,
formatting_func=prepare_sample_text,
infinite=False,
seq_length=args.seq_length,
chars_per_token=chars_per_token,
)
return train_dataset, valid_dataset
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
script_args.model_name,
quantization_config=bnb_config,
device_map={"": 0},
trust_remote_code=True,
use_auth_token=True,
)
base_model.config.use_cache = False
peft_config = LoraConfig(
r=script_args.lora_r,
lora_alpha=script_args.lora_alpha,
lora_dropout=script_args.lora_dropout,
target_modules=["q_proj", "v_proj"],
bias="none",
task_type="CAUSAL_LM",
)
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
training_args = TrainingArguments(
output_dir=script_args.output_dir,
per_device_train_batch_size=script_args.per_device_train_batch_size,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
learning_rate=script_args.learning_rate,
logging_steps=script_args.logging_steps,
max_steps=script_args.max_steps,
report_to=script_args.log_with,
save_steps=script_args.save_steps,
group_by_length=script_args.group_by_length,
lr_scheduler_type=script_args.lr_scheduler_type,
warmup_steps=script_args.num_warmup_steps,
optim=script_args.optimizer_type,
bf16=True,
remove_unused_columns=False,
run_name="sft_llama2",
)
train_dataset, eval_dataset = create_datasets(tokenizer, script_args)
trainer = SFTTrainer(
model=base_model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
packing=script_args.packing,
max_seq_length=None,
tokenizer=tokenizer,
args=training_args,
)
trainer.train()
trainer.save_model(script_args.output_dir)
output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)
# Free memory for merging weights
del base_model
torch.cuda.empty_cache()
model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()
output_merged_dir = os.path.join(script_args.output_dir, "final_merged_checkpoint")
model.save_pretrained(output_merged_dir, safe_serialization=True)
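A quick sanity-check sketch (not part of the file above; the prompt, paths and generation settings are illustrative assumptions) of loading the merged checkpoint and sampling in the same Question/Answer format used by `prepare_sample_text`:
# Hypothetical inference check on the merged checkpoint.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
merged_dir = "./results/final_merged_checkpoint"  # assumed default output path
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # base model tokenizer
model = AutoModelForCausalLM.from_pretrained(merged_dir, torch_dtype=torch.bfloat16, device_map="auto")
prompt = "Question: How do I reverse a list in Python?\n\nAnswer: "
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=64, do_sample=True, top_p=0.9)
print(tokenizer.decode(output[0], skip_special_tokens=True))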

View File

@ -0,0 +1,119 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import numpy as np
import torch
from transformers import AutoTokenizer, load_tool
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
def generate_data(n):
"""Generate random arithmetic tasks and answers."""
tasks, answers = [], []
for _ in range(n):
a = np.random.randint(0, 50)
b = np.random.randint(0, 50)
op = np.random.choice(["-", "+", "*"])
tasks.append(f"\n\nWhat is {a} {op} {b}?")
if op == "-":
answers.append(a - b)
elif op == "+":
answers.append(a + b)
else:
answers.append(a * b)
return tasks, answers
def exact_match_reward(responses, answers=None):
"""Reward if generated response contains correct answer."""
rewards = []
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
for response, answer in zip(responses, answers):
reward = 0.0
predicted_number = None
match_pattern = re.findall(pattern, response)
if match_pattern:
predicted_number = float(match_pattern[0])
if predicted_number is not None:
if np.abs(predicted_number - answer) < 0.01:
reward += 1.0
rewards.append(torch.tensor(reward))
return rewards
# set up models
model_id = "gpt2"
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
# system prompt
prompt = """\
What is 13-3?
<request><SimpleCalculatorTool>13-3<call>10.0<response>
Result=10<submit>
What is 4*3?
<request><SimpleCalculatorTool>4*3<call>12.0<response>
Result=12<submit>"""
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": -1,
"max_new_tokens": 32,
}
# trainer
ppo_config = PPOConfig(
batch_size=256,
learning_rate=1.41e-5,
mini_batch_size=64,
log_with="wandb",
)
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)
# text env
text_env = TextEnvironment(
model,
tokenizer,
{"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
exact_match_reward,
prompt,
generation_kwargs=generation_kwargs,
)
# main training loop
for step in range(100):
tasks, answers = generate_data(ppo_config.batch_size)
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
response_texts = [tokenizer.decode(response) for response in responses]
query_texts = [tokenizer.decode(query) for query in queries]
texts = {"query": [qt.split("<submit>")[-1].strip() for qt in query_texts], "response": response_texts}
ppo_trainer.log_stats(train_stats, texts, rewards)
ppo_trainer.save_pretrained(model_id + "-calculator")
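For reference, a small self-contained sketch of what the `exact_match_reward` above does (the sample response string is made up; the regex is copied from the function): it looks for "Result = <number> <submit>" in the generated text and pays 1.0 if the number matches the answer.
import re
import torch
def exact_match_reward_demo(response, answer):
    # same pattern as in the script above
    pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>"
    match = re.findall(pattern, response)
    correct = bool(match) and abs(float(match[0]) - answer) < 0.01
    return torch.tensor(1.0 if correct else 0.0)
print(exact_match_reward_demo("<request><SimpleCalculatorTool>13-3<call>10.0<response>Result=10<submit>", 10))
# tensor(1.)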

View File

@ -0,0 +1,194 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from dataclasses import dataclass, field
from typing import Optional
import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, HfArgumentParser, load_tool
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@dataclass
class ScriptArguments:
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "the number of gradient accumulation steps"}
)
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
n_epochs: Optional[int] = field(default=32, metadata={"help": "max number of ppo epochs"})
parser = HfArgumentParser(ScriptArguments)
args = parser.parse_args_into_dataclasses()[0]
def exact_match_reward(responses, answers=None):
"""Reward if generated response contains correct answer."""
rewards = []
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
for response, answer in zip(responses, answers):
reward = 0.0
try:
predicted_number = None
match_pattern = re.findall(pattern, response)
if match_pattern:
predicted_number = float(match_pattern[0])
if predicted_number is not None:
if np.abs((predicted_number - float(answer))) < 0.1:
reward += 1.0
except: # noqa
pass
rewards.append(torch.tensor(reward))
return rewards
def evaluate(test_dataloader, text_env, ppo_trainer):
test_rewards = []
for test_batch in test_dataloader:
_, _, _, rewards, _ = text_env.run(test_batch["query"], answers=test_batch["answer"])
test_rewards.extend(rewards)
test_rewards = ppo_trainer.accelerator.gather_for_metrics(
torch.stack(test_rewards).to(ppo_trainer.accelerator.device)
)
return test_rewards.mean()
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["c_proj", "c_attn", "q_attn"],
)
# set up models
model = AutoModelForCausalLMWithValueHead.from_pretrained(
args.model_name,
use_auth_token=True,
load_in_4bit=True,
peft_config=lora_config,
)
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token
ds = load_dataset("gsm8k", "main", split="train")
ds = ds.rename_columns({"question": "query"})
ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
ds = ds.select(range(1, len(ds)))  # skip the first sample, which is used in the prompt
ds_test = load_dataset("gsm8k", "main", split="test")
ds_test = ds_test.rename_columns({"question": "query"})
ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=args.batch_size)
# prompt
prompt = """\
Example of using a Python API to solve math questions.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
<request><PythonInterpreter>
def solution():
money_initial = 23
bagels = 5
bagel_cost = 3
money_spent = bagels * bagel_cost
money_left = money_initial - money_spent
result = money_left
return result
print(solution())
<call>72<response>
Result = 72 <submit>
Q: """
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": -1,
"max_new_tokens": args.max_new_tokens,
}
# trainer
ppo_config = PPOConfig(
batch_size=args.batch_size,
learning_rate=args.learning_rate,
mini_batch_size=args.mini_batch_size,
ppo_epochs=args.ppo_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps,
log_with="wandb",
tracker_project_name="trl-gsm8k",
remove_unused_columns=False,
optimize_cuda_cache=True,
)
ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader)
# text env
text_env = TextEnvironment(
model,
tokenizer,
[load_tool("lvwerra/python-interpreter")],
exact_match_reward,
prompt,
max_turns=2,
generation_kwargs=generation_kwargs,
)
# main training loop
for epoch in range(args.n_epochs):
for step, batch in enumerate(ppo_trainer.dataloader):
if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
else:
reward_mean_test = None
queries, responses, masks, rewards, histories = text_env.run(batch["query"], answers=batch["answer"])
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
# logging
if reward_mean_test is not None:
train_stats["env/reward_mean_test"] = reward_mean_test
texts = {
"query": batch["query"],
"response": [tokenizer.decode(response) for response in responses],
"answer": batch["answer"],
}
ppo_trainer.log_stats(train_stats, texts, rewards)
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
ppo_trainer.save_pretrained(f"model/{args.model_name}-gsm8k")

View File

@ -0,0 +1,189 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoTokenizer, HfArgumentParser, load_tool
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@dataclass
class ScriptArguments:
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "the number of gradient accumulation steps"}
)
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
iterations: Optional[int] = field(default=1000, metadata={"help": "the number of iterations"})
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
parser = HfArgumentParser(ScriptArguments)
args = parser.parse_args_into_dataclasses()[0]
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules=["c_proj", "c_attn", "q_attn"],
)
# set up models
model = AutoModelForCausalLMWithValueHead.from_pretrained(
args.model_name,
use_auth_token=True,
trust_remote_code=True,
load_in_4bit=True,
peft_config=lora_config,
)
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True)
tokenizer.pad_token = tokenizer.eos_token
# system prompt
prompt = """\
Answer the following question:
Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>
Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>
Q: """
generation_kwargs = {
"min_length": -1,
"top_k": 0.0,
"top_p": 1.0,
"do_sample": True,
"pad_token_id": tokenizer.eos_token_id,
"eos_token_id": -1,
"max_new_tokens": args.max_new_tokens,
}
# trainer
config = PPOConfig(
batch_size=args.batch_size,
model_name=args.model_name,
learning_rate=args.learning_rate,
log_with=args.log_with,
mini_batch_size=args.mini_batch_size,
ppo_epochs=args.ppo_epochs,
gradient_accumulation_steps=args.gradient_accumulation_steps,
seed=args.seed,
optimize_cuda_cache=True,
)
ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)
dataset = load_dataset("trivia_qa", "rc", split="train")
local_seed = args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime
dataset = dataset.shuffle(local_seed)
def data_generator():
for i in range(len(dataset)):
yield dataset[i]["question"], [item for item in dataset[i]["answer"]["normalized_aliases"]]
gen = data_generator()
gen = iter(gen)
def generate_data(n):
tasks, answers = [], []
for i in range(n):
q, a = next(gen)
tasks.append(q)
answers.append(a)
return tasks, answers
def exact_match_reward(responses, answers=None):
"""Reward if generated response contains correct answer."""
rewards = []
for response, answer in zip(responses, answers):
reward = 0.0
for a in answer:
if a.lower() in response.lower():
reward += 1.0
break
rewards.append(torch.tensor(reward))
return rewards
# text env
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
# limit the amount of tokens
tool_fn = lambda x: tool(x).split("\n")[1][:600] # noqa
text_env = TextEnvironment(
model,
tokenizer,
{"Wiki": tool_fn},
exact_match_reward,
prompt,
generation_kwargs=generation_kwargs,
max_tool_reponse=400,
)
def print_trainable_parameters(model):
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
)
print_trainable_parameters(model)
# main training loop
for i in range(args.iterations):
tasks, answers = generate_data(config.batch_size)
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
response_texts = [tokenizer.decode(response) for response in responses]
query_texts = [tokenizer.decode(query) for query in queries]
texts = {
"query": [qt.split("<submit>")[-1].strip() for qt in query_texts],
"response": response_texts,
"answer": [", ".join(item) for item in answers],
}
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
ppo_trainer.log_stats(train_stats, texts, [item for item in all_rewards])
if i % 100 == 0:
ppo_trainer.save_pretrained(f"models/{args.model_name}_{args.seed}_{i}_triviaqa")

View File

@ -1,8 +1,9 @@
# Maintained scripts
This folder shows multiple ways to use the objects from TRL such as `SFTTrainer`, `RewardTrainer` and `PPOTrainer` in different scenarios.
This folder shows multiple ways to use the objects from TRL such as `SFTTrainer`, `RewardTrainer`, `DDPOTrainer` and `PPOTrainer` in different scenarios.
- `sft_trainer.py`: This script shows how to use the SFTTrainer to fine-tune a model or adapters on a target dataset.
- `reward_trainer.py`: This script shows how to use the RewardTrainer to train a reward model on your own dataset.
- `sentiment_tuning.py`: This script shows how to use the PPOTrainer to fine-tune a sentiment analysis model using the IMDB dataset.
- `multi_adapter_rl.py`: This script shows how to use the PPOTrainer to train a single base model with multiple adapters. This script requires you to run the example script with the reward model training beforehand.
- `multi_adapter_rl.py`: This script shows how to use the PPOTrainer to train a single base model with multiple adapters. This script requires you to run the example script with the reward model training beforehand.
- `stable_diffusion_tuning_example.py`: This script shows how to use the DDPOTrainer to fine-tune a stable diffusion model using reinforcement learning.

View File

@ -0,0 +1,150 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Optional
import torch
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import BitsAndBytesConfig, HfArgumentParser, LlamaTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import LengthSampler
input_min_text_length = 6
input_max_text_length = 12
@dataclass
class ScriptArguments:
"""
The name of the Causal LM model we wish to fine-tune with PPO
"""
model_name: Optional[str] = field(default="huggyllama/llama-7b", metadata={"help": "the model name"})
dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
rm_adapter: Optional[str] = field(
default="trl-lib/llama-7b-hh-rm-adapter", metadata={"help": "the rm adapter name"}
)
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
use_safetensors: Optional[bool] = field(default=False, metadata={"help": "Use safetensors"})
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
use_score_norm: Optional[bool] = field(
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
)
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})
parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
def create_and_prepare_dataset(tokenizer):
dataset = load_dataset(script_args.dataset_name, split="train[:1%]")
input_size = LengthSampler(input_min_text_length, input_max_text_length)
def tokenize(example):
text_size = input_size()
example["input_ids"] = tokenizer.encode(example["chosen"])[:text_size]
example["query"] = tokenizer.decode(example["input_ids"])
return example
dataset = dataset.map(tokenize, batched=False)
dataset.set_format("torch")
return dataset
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
nf4_config = BitsAndBytesConfig(
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
script_args.model_name,
device_map={"": 0},
peft_config=lora_config,
quantization_config=nf4_config,
reward_adapter=script_args.rm_adapter,
use_safetensors=script_args.use_safetensors,
)
tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name)
tokenizer.pad_token = tokenizer.eos_token
dataset = create_and_prepare_dataset(tokenizer)
def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
config = PPOConfig(
model_name=script_args.model_name,
log_with=script_args.log_with,
learning_rate=1e-5,
batch_size=8,
mini_batch_size=2,
gradient_accumulation_steps=2,
optimize_cuda_cache=True,
seed=script_args.seed,
use_score_scaling=script_args.use_score_scaling,
use_score_norm=script_args.use_score_norm,
score_clip=script_args.score_clip,
)
ppo_trainer = PPOTrainer(
config,
model,
ref_model=None,
tokenizer=tokenizer,
dataset=dataset,
data_collator=collator,
)
generation_kwargs = {
"top_k": 0.0,
"top_p": 0.9,
"do_sample": True,
"pad_token_id": tokenizer.pad_token_id,
}
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
question_tensors = batch["input_ids"]
response_tensors = ppo_trainer.generate(
question_tensors,
return_prompt=False,
**generation_kwargs,
)
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
# Compute reward score
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(ppo_trainer.accelerator.device)
raw_rewards = ppo_trainer.model.compute_reward_score(**inputs)
rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))] # take last token
# Run PPO step
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
ppo_trainer.log_stats(stats, batch, rewards)
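For clarity, a tiny illustration (with made-up records) of what the `collator` above produces: it transposes a list of per-example dicts into a dict of lists, which is the batch layout the PPOTrainer dataloader hands to the training loop.
def collator_demo(data):
    # same logic as the collator above
    return dict((key, [d[key] for d in data]) for key in data[0])
print(collator_demo([{"query": "a", "input_ids": [1, 2]}, {"query": "b", "input_ids": [3]}]))
# {'query': ['a', 'b'], 'input_ids': [[1, 2], [3]]}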

View File

@ -45,7 +45,6 @@ class ScriptArguments:
default=1, metadata={"help": "the number of gradient accumulation steps"}
)
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
target_kl: Optional[float] = field(default=6, metadata={"help": "kl target for early stopping"})
use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
use_seq2seq: Optional[bool] = field(default=False, metadata={"help": "whether to use seq2seq models"})
kl_penalty: Optional[str] = field(
@ -56,6 +55,11 @@ class ScriptArguments:
)
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
use_score_norm: Optional[bool] = field(
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
)
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})
parser = HfArgumentParser(ScriptArguments)
@ -72,8 +76,13 @@ config = PPOConfig(
target_kl=script_args.target_kl,
kl_penalty=script_args.kl_penalty,
seed=script_args.seed,
use_score_scaling=script_args.use_score_scaling,
use_score_norm=script_args.use_score_norm,
score_clip=script_args.score_clip,
)
# set seed before initializing value head for deterministic eval
set_seed(config.seed)
# We then define the arguments to pass to the sentiment analysis pipeline.
# We set `return_all_scores` to True to get the sentiment score for each token.
@ -127,9 +136,6 @@ def collator(data):
return dict((key, [d[key] for d in data]) for key in data[0])
# set seed before initializing value head for deterministic eval
set_seed(config.seed)
# Now let's build the model, the reference model, and the tokenizer.
if not script_args.use_peft:
ref_model = trl_model_class.from_pretrained(config.model_name)

View File

@ -57,6 +57,12 @@ class ScriptArguments:
use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
save_steps: Optional[int] = field(
default=100, metadata={"help": "Number of update steps between two checkpoint saves"}
)
save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
parser = HfArgumentParser(ScriptArguments)
@ -98,6 +104,11 @@ training_args = TrainingArguments(
logging_steps=script_args.logging_steps,
num_train_epochs=script_args.num_train_epochs,
max_steps=script_args.max_steps,
report_to=script_args.log_with,
save_steps=script_args.save_steps,
save_total_limit=script_args.save_total_limit,
push_to_hub=script_args.push_to_hub,
hub_model_id=script_args.hub_model_id,
)
# Step 4: Define the LoraConfig

View File

@ -0,0 +1,230 @@
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import numpy as np
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError
from transformers import CLIPModel, CLIPProcessor
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
class MLP(nn.Module):
def __init__(self):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(768, 1024),
nn.Dropout(0.2),
nn.Linear(1024, 128),
nn.Dropout(0.2),
nn.Linear(128, 64),
nn.Dropout(0.1),
nn.Linear(64, 16),
nn.Linear(16, 1),
)
@torch.no_grad()
def forward(self, embed):
return self.layers(embed)
class AestheticScorer(torch.nn.Module):
"""
This model attempts to predict the aesthetic score of an image. The aesthetic score
is a numerical approximation of how much a specific image is liked by humans on average.
This is from https://github.com/christophschuhmann/improved-aesthetic-predictor
"""
def __init__(self, *, dtype, model_id, model_filename):
super().__init__()
self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
self.mlp = MLP()
try:
cached_path = hf_hub_download(model_id, model_filename)
except EntryNotFoundError:
cached_path = os.path.join(model_id, model_filename)
state_dict = torch.load(cached_path)
self.mlp.load_state_dict(state_dict)
self.dtype = dtype
self.eval()
@torch.no_grad()
def __call__(self, images):
device = next(self.parameters()).device
inputs = self.processor(images=images, return_tensors="pt")
inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
embed = self.clip.get_image_features(**inputs)
# normalize embedding
embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
return self.mlp(embed).squeeze(1)
def aesthetic_scorer(hub_model_id, model_filename):
scorer = AestheticScorer(
model_id=hub_model_id,
model_filename=model_filename,
dtype=torch.float32,
).cuda()
def _fn(images, prompts, metadata):
images = (images * 255).round().clamp(0, 255).to(torch.uint8)
scores = scorer(images)
return scores, {}
return _fn
# list of example prompts to feed stable diffusion
animals = [
"cat",
"dog",
"horse",
"monkey",
"rabbit",
"zebra",
"spider",
"bird",
"sheep",
"deer",
"cow",
"goat",
"lion",
"frog",
"chicken",
"duck",
"goose",
"bee",
"pig",
"turkey",
"fly",
"llama",
"camel",
"bat",
"gorilla",
"hedgehog",
"kangaroo",
]
def prompt_fn():
return np.random.choice(animals), {}
def image_outputs_logger(image_data, global_step, accelerate_logger):
# For the sake of this example, we will only log the last batch of images
# and associated data
result = {}
images, prompts, _, rewards, _ = image_data[-1]
for i, image in enumerate(images):
prompt = prompts[i]
reward = rewards[i].item()
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0)
accelerate_logger.log_images(
result,
step=global_step,
)
def parse_arguments():
parser = argparse.ArgumentParser(description="DDPOConfig settings and pretrained model details.")
# DDPOConfig arguments
parser.add_argument("--num_epochs", type=int, default=200)
parser.add_argument("--train_gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--sample_num_steps", type=int, default=50)
parser.add_argument("--sample_batch_size", type=int, default=6)
parser.add_argument("--train_batch_size", type=int, default=3)
parser.add_argument("--sample_num_batches_per_epoch", type=int, default=4)
parser.add_argument("--per_prompt_stat_tracking", action="store_true", default=True)
parser.add_argument("--per_prompt_stat_tracking_buffer_size", type=int, default=32)
parser.add_argument("--tracker_project_name", default="stable_diffusion_training")
parser.add_argument("--log_with", default="wandb")
parser.add_argument("--logging_dir", default="./logs")
parser.add_argument("--automatic_checkpoint_naming", action="store_true", default=True)
parser.add_argument("--total_limit", type=int, default=5)
parser.add_argument("--project_dir", default="./save")
parser.add_argument("--pretrained_model", default="runwayml/stable-diffusion-v1-5")
parser.add_argument("--pretrained_revision", default="main")
parser.add_argument("--hf_user_access_token", required=True)
parser.add_argument(
"--hf_hub_model_id",
help="HuggingFace repo to save model weights to",
default="ddpo-finetuned-stable-diffusion",
)
parser.add_argument(
"--hf_hub_aesthetic_model_id",
help="HuggingFace model ID for aesthetic scorer model weights",
default="trl-lib/ddpo-aesthetic-predictor",
)
parser.add_argument(
"--hf_hub_aesthetic_model_filename",
default="aesthetic-model.pth",
help="HuggingFace model filename for aesthetic scorer model weights",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_arguments()
project_kwargs = {
"logging_dir": args.logging_dir,
"automatic_checkpoint_naming": args.automatic_checkpoint_naming,
"total_limit": args.total_limit,
"project_dir": args.project_dir,
}
config = DDPOConfig(
num_epochs=args.num_epochs,
train_gradient_accumulation_steps=args.train_gradient_accumulation_steps,
sample_num_steps=args.sample_num_steps,
sample_batch_size=args.sample_batch_size,
train_batch_size=args.train_batch_size,
sample_num_batches_per_epoch=args.sample_num_batches_per_epoch,
per_prompt_stat_tracking=args.per_prompt_stat_tracking,
per_prompt_stat_tracking_buffer_size=args.per_prompt_stat_tracking_buffer_size,
tracker_project_name=args.tracker_project_name,
log_with=args.log_with,
project_kwargs=project_kwargs,
)
pipeline = DefaultDDPOStableDiffusionPipeline(
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=True
)
trainer = DDPOTrainer(
config,
aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename),
prompt_fn,
pipeline,
image_samples_hook=image_outputs_logger,
)
trainer.train()
trainer.push_to_hub(args.hf_hub_model_id, token=args.hf_user_access_token)
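The reward callable wired into DDPOTrainer above takes `(images, prompts, metadata)` and returns per-image scores plus a metadata dict, as the test file later in this diff also shows; a toy stand-in (random scores, purely illustrative) could look like:
import torch
def toy_scorer(images, prompts, metadata):
    # one random "reward" per prompt and an empty metadata dict
    return torch.randn(len(prompts)), {}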

View File

@ -3,4 +3,4 @@ torch>=1.4.0
tqdm
transformers
accelerate
peft>=0.3.0
peft>=0.3.0

View File

@ -30,7 +30,7 @@ LABELS_TO_EXEMPT = [
def main():
g = Github(os.environ["GITHUB_TOKEN"])
repo = g.get_repo("lvwerra/trl")
repo = g.get_repo("huggingface/trl")
open_issues = repo.get_issues(state="open")
for issue in open_issues:

View File

@ -57,7 +57,7 @@ To create the package for pypi.
from setuptools import find_packages, setup
__version__ = "0.5.0.dev0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.7.0" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
REQUIRED_PKGS = [
"torch>=1.4.0",
@ -67,9 +67,10 @@ REQUIRED_PKGS = [
"datasets",
]
EXTRAS = {
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "peft"],
"test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "peft", "diffusers>=0.18.0"],
"peft": ["peft>=0.2.0"],
"dev": ["parameterized", "pytest", "pytest-xdist", "pre-commit", "peft>=0.2.0"],
"diffusers": ["diffusers>=0.18.0"],
"dev": ["parameterized", "pytest", "pytest-xdist", "pre-commit", "peft>=0.2.0", "diffusers>=0.18.0"],
}
setup(
@ -88,7 +89,7 @@ setup(
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
url="https://github.com/lvwerra/trl",
url="https://github.com/huggingface/trl",
packages=find_packages(),
include_package_data=True,
install_requires=REQUIRED_PKGS,

View File

@ -0,0 +1,66 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM
class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
def test_data_collator_finds_response_template_llama2_tokenizer(self):
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
self.instruction = """### System: You are a helpful assistant.
### User: How much is 2+2?
### Assistant: 2+2 equals 4"""
self.response_template = "\n### Assistant:"
# GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25]
# Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901]
self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:]
# Plain check on string
self.assertIn(self.response_template, self.instruction)
self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)
# Test the fix for #598
# Pass the already tokenized (with context) and truncated response_template so the token ids match those in the instruction + response
self.collator = DataCollatorForCompletionOnlyLM(self.tokenized_response_w_context, tokenizer=self.tokenizer)
self.collator.torch_call([self.tokenized_instruction])
def test_data_collator_handling_of_long_sequences(self):
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
self.instruction = """### System: You are a helpful assistant.
### User: How much is 2+2? I'm asking because I'm not sure. And I'm not sure because I'm not good at math.
"""
self.response_template = "\n### Assistant:"
# check DataCollatorForCompletionOnlyLM using response template only
self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)
self.collator = DataCollatorForCompletionOnlyLM(self.response_template, tokenizer=self.tokenizer)
encoded_instance = self.collator.torch_call([self.tokenized_instruction])
result = torch.all(encoded_instance["labels"] == -100)
self.assertTrue(result, "Not all values in the tensor are -100.")
# check DataCollatorForCompletionOnlyLM using response template and instruction template
self.instruction_template = "\n### User:"
self.collator = DataCollatorForCompletionOnlyLM(
self.response_template, self.instruction_template, tokenizer=self.tokenizer
)
encoded_instance = self.collator.torch_call([self.tokenized_instruction])
result = torch.all(encoded_instance["labels"] == -100)
self.assertTrue(result, "Not all values in the tensor are -100.")
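Outside of these unit tests, the collator is typically handed to `SFTTrainer` as `data_collator` so that only the assistant completion contributes to the loss. A hedged sketch (the tiny model, dataset and field names are assumptions, and the token-id slicing follows the first test above):
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer
model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id)
dataset = Dataset.from_dict({"text": ["### User: How much is 2+2?\n### Assistant: 2+2 equals 4"]})
# Tokenize the response template with context and drop the context-dependent prefix,
# as in the test above, so the ids match those inside a full example.
response_ids = tokenizer.encode("\n### Assistant:", add_special_tokens=False)[2:]
collator = DataCollatorForCompletionOnlyLM(response_ids, tokenizer=tokenizer)
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    dataset_text_field="text",
    data_collator=collator,
    tokenizer=tokenizer,
    max_seq_length=64,
    packing=False,  # completion-only masking requires unpacked examples
)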

View File

@ -0,0 +1,92 @@
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import torch
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
def scorer_function(images, prompts, metadata):
return torch.randn(1) * 3.0, {}
def prompt_function():
return ("cabbages", {})
class DDPOTrainerTester(unittest.TestCase):
"""
Test the DDPOTrainer class.
"""
def setUp(self):
self.ddpo_config = DDPOConfig(
num_epochs=2,
train_gradient_accumulation_steps=1,
per_prompt_stat_tracking_buffer_size=32,
sample_num_batches_per_epoch=2,
sample_batch_size=2,
mixed_precision=None,
save_freq=1000000,
)
pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch"
pretrained_revision = "main"
pipeline = DefaultDDPOStableDiffusionPipeline(
pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False
)
self.trainer = DDPOTrainer(self.ddpo_config, scorer_function, prompt_function, pipeline)
return super().setUp()
def tearDown(self) -> None:
gc.collect()
def test_loss(self):
advantage = torch.tensor([-1.0])
clip_range = 0.0001
ratio = torch.tensor([1.0])
loss = self.trainer.loss(advantage, clip_range, ratio)
self.assertEqual(loss.item(), 1.0)
def test_generate_samples(self):
samples, output_pairs = self.trainer._generate_samples(1, 2)
self.assertEqual(len(samples), 1)
self.assertEqual(len(output_pairs), 1)
self.assertEqual(len(output_pairs[0][0]), 2)
def test_calculate_loss(self):
samples, _ = self.trainer._generate_samples(1, 2)
sample = samples[0]
latents = sample["latents"][0, 0].unsqueeze(0)
next_latents = sample["next_latents"][0, 0].unsqueeze(0)
log_probs = sample["log_probs"][0, 0].unsqueeze(0)
timesteps = sample["timesteps"][0, 0].unsqueeze(0)
prompt_embeds = sample["prompt_embeds"]
advantage = torch.tensor([1.0], device=prompt_embeds.device)
self.assertEqual(latents.shape, (1, 4, 64, 64))
self.assertEqual(next_latents.shape, (1, 4, 64, 64))
self.assertEqual(log_probs.shape, (1,))
self.assertEqual(timesteps.shape, (1,))
self.assertEqual(prompt_embeds.shape, (2, 77, 32))
loss, approx_kl, clipfrac = self.trainer.calculate_loss(
latents, timesteps, next_latents, log_probs, advantage, prompt_embeds
)
self.assertTrue(torch.isfinite(loss.cpu()))

View File

@ -16,10 +16,13 @@ import unittest
import torch
from datasets import Dataset
from pytest import mark
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer
from .testing_utils import require_peft
class DPOTrainerTester(unittest.TestCase):
@classmethod
@ -30,6 +33,40 @@ class DPOTrainerTester(unittest.TestCase):
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
cls.tokenizer.pad_token = cls.tokenizer.eos_token
def _init_dummy_dataset(self):
# fmt: off
dummy_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Python",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"Java",
],
}
# fmt: on
return Dataset.from_dict(dummy_dataset_dict)
def test_dpo_trainer(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
@ -42,38 +79,7 @@ class DPOTrainerTester(unittest.TestCase):
evaluation_strategy="steps",
)
# fmt: off
dummy_dataset_dict = {
"prompt": [
"hello",
"how are you",
"What is your name?",
"What is your name?",
"Which is the best programming language?",
"Which is the best programming language?",
"Which is the best programming language?",
],
"chosen": [
"hi nice to meet you",
"I am fine",
"My name is Mary",
"My name is Mary",
"Python",
"Python",
"Python",
],
"rejected": [
"leave me alone",
"I am not fine",
"Whats it to you?",
"I dont have a name",
"Javascript",
"C++",
"Java",
],
}
# fmt: on
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
dummy_dataset = self._init_dummy_dataset()
trainer = DPOTrainer(
model=self.model,
@ -97,3 +103,91 @@ class DPOTrainerTester(unittest.TestCase):
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
def test_dpo_trainer_without_providing_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
)
dummy_dataset = self._init_dummy_dataset()
trainer = DPOTrainer(
model=self.model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))
@require_peft
@mark.peft_test
def test_dpo_trainer_without_providing_ref_model_with_lora(self):
from peft import LoraConfig
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = TrainingArguments(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=4,
learning_rate=9e-1,
evaluation_strategy="steps",
)
dummy_dataset = self._init_dummy_dataset()
trainer = DPOTrainer(
model=self.model,
ref_model=None,
beta=0.1,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=dummy_dataset,
eval_dataset=dummy_dataset,
peft_config=lora_config,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# check the params have changed
for n, param in previous_trainable_params.items():
if "lora" in n:
new_param = trainer.model.get_parameter(n)
# check the params have changed - ignore 0 biases
if param.sum() != 0:
self.assertFalse(torch.equal(param, new_param))

tests/test_environments.py Normal file
View File

@ -0,0 +1,273 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from unittest.mock import patch
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, TextEnvironment, TextHistory
class DummyTool:
def __call__(self, text):
return text
def dummy_generate(histories):
for i in range(len(histories)):
histories[i].append_segment("<request><DummyTool>test<call>", torch.tensor([1, 2, 3]), system=False)
return histories
class TextHistoryTest(unittest.TestCase):
def test_text_history_init(self):
text = "Hello there!"
tokens = torch.tensor([1, 2, 3])
history = TextHistory(text, tokens)
self.assertEqual(history.text, text)
self.assertTrue(torch.equal(history.tokens, tokens))
self.assertTrue(torch.equal(history.token_masks, torch.zeros_like(tokens)))
history = TextHistory(text, tokens, system=False)
self.assertTrue(torch.equal(history.token_masks, torch.ones_like(tokens)))
def test_text_history_append_segment(self):
text = "Hello there!"
tokens = torch.tensor([1, 2, 3])
history = TextHistory(text, tokens)
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False)
self.assertEqual(history.text, text + "General Kenobi!")
self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6])))
self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1])))
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
self.assertEqual(history.text, text + "General Kenobi!" + "You are a bold one!")
self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])))
self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0])))
def test_text_history_complete(self):
text = "Hello there!"
tokens = torch.tensor([1, 2, 3])
history = TextHistory(text, tokens)
history.complete()
self.assertTrue(history.completed)
self.assertFalse(history.truncated)
history.complete(truncated=True)
self.assertTrue(history.completed)
self.assertTrue(history.truncated)
def test_text_history_last_segment(self):
text = "Hello there!"
tokens = torch.tensor([1, 2, 3])
history = TextHistory(text, tokens)
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]))
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
self.assertEqual(history.last_text_segment, "You are a bold one!")
def test_text_history_split_query_response(self):
text = "Hello there!"
tokens = torch.tensor([1, 2, 3])
history = TextHistory(text, tokens)
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False)
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]), system=True)
query, response, mask = history.split_query_response_tokens()
self.assertTrue(torch.equal(query, torch.tensor([1, 2, 3])))
self.assertTrue(torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9])))
self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0])))
class TextEnvironmentTester(unittest.TestCase):
@classmethod
def setUpClass(cls):
# model_id
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
# get models and tokenizer
cls.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id)
cls.gpt2_tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
cls.gpt2_tokenizer.pad_token = cls.gpt2_tokenizer.eos_token
def test_text_environment_setup(self):
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools=[DummyTool()],
reward_fn=lambda x: torch.tensor(1),
prompt="I am a prompt!\n",
)
self.assertEqual(env.prompt, "I am a prompt!\n")
self.assertEqual(list(env.tools.keys()), ["DummyTool"])
self.assertTrue(isinstance(env.tools["DummyTool"], DummyTool))
self.assertEqual(env.reward_fn("Hello there!"), 1)
def test_text_environment_generate(self):
generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id}
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools=[DummyTool()],
reward_fn=lambda x: torch.tensor(1),
prompt="I am a prompt!\n",
generation_kwargs=generation_kwargs,
)
input_texts = ["this is a test", "this is another, longer test"]
model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]
generations_batched = env._generate_batched(model_inputs, batch_size=2)
generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched)
generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs]
generations_single = self.gpt2_tokenizer.batch_decode(generations_single)
self.assertEqual(generations_single, generations_batched)
def test_text_environment_tool_call_parsing(self):
string_valid = "Something something <request><Tool1>Hello there!<call>"
string_invalid_request = "Something something <Tool1>Hello there!<call>"
string_invalid_call = "Something something <request><Tool1>Hello there!"
string_invalid_tool = "Something something <request>|Tool2|Hello there!<call>"
string_invalid_random = "<>abcdefghijklm<>nopqrstuvwxyz<>"
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools=[DummyTool()],
reward_fn=lambda x: torch.tensor(1),
prompt="I am a prompt!\n",
)
tool, response = env.parse_tool_call(string_valid)
self.assertEqual(tool, "Tool1")
self.assertEqual(response, "Hello there!")
tool, response = env.parse_tool_call(string_invalid_request)
self.assertEqual(tool, None)
self.assertEqual(response, None)
tool, response = env.parse_tool_call(string_invalid_call)
self.assertEqual(tool, None)
self.assertEqual(response, None)
tool, response = env.parse_tool_call(string_invalid_tool)
self.assertEqual(tool, None)
self.assertEqual(response, None)
tool, response = env.parse_tool_call(string_invalid_random)
self.assertEqual(tool, None)
self.assertEqual(response, None)
def test_text_environment_tool_truncation(self):
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools={"dummy": lambda x: "a" * 1000},
reward_fn=lambda x: torch.tensor(1),
prompt="I am a prompt!\n",
)
env.max_tool_response = 100
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 100)
env.max_tool_response = 500
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 500)
env.max_tool_response = 1001
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000)
env.max_tool_response = 2000
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000)
@patch.object(TextEnvironment, "generate", side_effect=dummy_generate)
def test_text_environment_max_calls(self, mock_generate):
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools={"DummyTool": DummyTool()},
reward_fn=lambda x: [torch.tensor(1) for _ in x],
prompt="I am a prompt!\n",
)
env.max_turns = 1
_, _, _, _, histories = env.run(["test"])
self.assertEqual(
histories[0].text, "I am a prompt!\n" + "test" + 1 * "<request><DummyTool>test<call>test<response>"
)
env.max_turns = 2
_, _, _, _, histories = env.run(["test"])
self.assertEqual(
histories[0].text, "I am a prompt!\n" + "test" + 2 * "<request><DummyTool>test<call>test<response>"
)
env.max_turns = 4
_, _, _, _, histories = env.run(["test"])
self.assertEqual(
histories[0].text, "I am a prompt!\n" + "test" + 4 * "<request><DummyTool>test<call>test<response>"
)
def test_text_environment_compute_rewards(self):
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools={"DummyTool": DummyTool()},
reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)],
prompt="I am a prompt!\n",
)
histories = [TextHistory("<request><DummyTool>test<call>", torch.tensor([1, 2, 3])) for _ in range(8)]
histories = env.compute_reward(histories)
for i in range(8):
self.assertEqual(histories[i].reward, i)
@patch.object(TextEnvironment, "generate", side_effect=dummy_generate)
def test_text_environment_run(self, mock_generate):
env = TextEnvironment(
self.gpt2_model,
self.gpt2_tokenizer,
tools={"DummyTool": DummyTool()},
reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)],
prompt="I am a prompt!\n",
max_turns=2,
)
task_1 = "Hello there!"
task_2 = "Hello there! General Kenobi!"
query, response, response_mask, reward, histories = env.run([task_1, task_2])
self.assertEqual(len(query[0]), 9)
self.assertEqual(len(query[1]), 12)
self.assertEqual(len(response[0]), 14)
self.assertEqual(len(response[1]), 14)
self.assertEqual(response_mask[0].sum(), 2 * 3) # mocked generate always adds 3 tokens
self.assertEqual(response_mask[1].sum(), 2 * 3) # mocked generate always adds 3 tokens
self.assertEqual(reward[0], 0)
self.assertEqual(reward[1], 1)
self.assertEqual(
histories[0].text, "I am a prompt!\n" + "Hello there!" + 2 * "<request><DummyTool>test<call>test<response>"
)
self.assertEqual(
histories[1].text,
"I am a prompt!\n" + "Hello there! General Kenobi!" + 2 * "<request><DummyTool>test<call>test<response>",
)

View File

@ -18,6 +18,7 @@ import re
import tempfile
import unittest
import pytest
import torch
from huggingface_hub import HfApi, HfFolder, delete_repo
from parameterized import parameterized
@ -208,6 +209,38 @@ class PPOTrainerTester(unittest.TestCase):
for stat in EXPECTED_STATS:
assert stat in train_stats.keys()
def test_ppo_step_with_masks(self):
# initialize dataset
dummy_dataset = self._init_dummy_dataset()
ppo_trainer = PPOTrainer(
config=self.ppo_config,
model=self.gpt2_model,
ref_model=self.gpt2_model_ref,
tokenizer=self.gpt2_tokenizer,
dataset=dummy_dataset,
)
dummy_dataloader = ppo_trainer.dataloader
# train model with ppo
for query_tensor, response_tensor in dummy_dataloader:
# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0), torch.tensor(0.0)]
response_mask = [torch.ones_like(r) for r in response_tensor]
# train model
train_stats = ppo_trainer.step(
[q for q in query_tensor], [r for r in response_tensor], reward, response_mask
)
break
for param in ppo_trainer.model.parameters():
assert param.grad is not None
for stat in EXPECTED_STATS:
assert stat in train_stats.keys()
def test_ppo_step_with_no_ref_sgd(self):
# initialize dataset
dummy_dataset = self._init_dummy_dataset()
@ -465,7 +498,7 @@ class PPOTrainerTester(unittest.TestCase):
# train model - this should raise an error
bs = ppo_trainer.config.batch_size
queries, responses, _ = ppo_trainer._step_safety_checker(
queries, responses, _, _ = ppo_trainer._step_safety_checker(
bs, [q for q in query_tensor], [r for r in response_tensor], reward
)
@ -1193,3 +1226,7 @@ class PPOTrainerTester(unittest.TestCase):
# train model
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
break
def test_batch_size_check(self):
with pytest.raises(ValueError):
PPOConfig(batch_size=2, mini_batch_size=2, gradient_accumulation_steps=2)

View File

@ -1,10 +1,11 @@
# flake8: noqa
__version__ = "0.4.8.dev0"
__version__ = "0.7.0"
from .core import set_seed
from .environment import TextEnvironment, TextHistory
from .extras import BestOfNSampler
from .import_utils import is_peft_available
from .import_utils import is_diffusers_available, is_peft_available
from .models import (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
@ -12,3 +13,13 @@ from .models import (
create_reference_model,
)
from .trainer import DataCollatorForCompletionOnlyLM, DPOTrainer, PPOConfig, PPOTrainer, RewardTrainer, SFTTrainer
if is_diffusers_available():
from .models import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
)
from .trainer import DDPOConfig, DDPOTrainer

View File

@ -0,0 +1,3 @@
# flake8: noqa
from .base_environment import TextEnvironment, TextHistory

View File

@ -0,0 +1,473 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import warnings
import torch
from accelerate.utils import extract_model_from_parallel
from transformers import StoppingCriteria, StoppingCriteriaList
from ..import_utils import is_rich_available
if is_rich_available():
from rich import print
from rich.text import Text
class StringStoppingCriteria(StoppingCriteria):
"""Custom `StoppingCriteria` which checks if all generations in the batch are completed."""
def __init__(self, stop_strings, tokenizer):
self.stop_strings = stop_strings
self.tokenizer = tokenizer
self.first_call = True
def __call__(self, input_ids, scores, **kwargs):
"""Returns true if all generated sequences contain any of the stop strings."""
if self.first_call:
self.generated_tokens = [1 for _ in range(input_ids.shape[0])]
self.start_length = input_ids.shape[-1] - 1
self.first_call = False
decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
done = []
for i, decoded_generation in enumerate(decoded_generations):
sequence_complete = any([stop_string in decoded_generation for stop_string in self.stop_strings])
done.append(sequence_complete)
if not sequence_complete:
self.generated_tokens[i] += 1
if all(done):
self.first_call = True
return all(done)
class TextHistory:
"""The TextHistory class keeps track of the history of an interaction between the language model and the environment."""
def __init__(self, text, tokens, system=True):
"""
Initialize TextHistory.
args:
text (`str`): The text of the first segment.
tokens (`torch.LongTensor`): The tokens of the first segment.
system (`bool`, *optional*): Whether the first segment is a system or user segment.
"""
self.system_spans = []
self.text_spans = []
self.token_spans = []
self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device)
self.text = ""
self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device)
self.completed = False
self.truncated = False
self.reward = 0.0
self.prompt_color = "black on grey85"
self.system_color = "black on cyan3"
self.model_color = "black on deep_sky_blue1"
self.reward_color = "black on plum1"
self.append_segment(text, tokens, system=system)
def append_segment(self, text, tokens, system=True):
"""
Append a new segment to the history.
args:
text (`str`): The text of the new segment.
tokens (`torch.LongTensor`): The tokens of the new segment.
system (`bool`, *optional*): Whether the new segment is a system or user segment.
"""
if len(text) == 0 or len(tokens) == 0:
raise ValueError("Can't append empty text or token list to history.")
original_text_length = len(self.text)
self.text += text
self.text_spans.append((original_text_length, len(self.text)))
self.system_spans.append(system)
original_token_length = len(self.tokens)
self.tokens = torch.cat((self.tokens, tokens))
if system:
self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens)))
else:
self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens)))
self.token_spans.append((original_token_length, len(self.tokens)))
def complete(self, truncated=False):
"""
Mark the history as completed.
"""
self.completed = True
self.truncated = truncated
@property
def last_text_segment(self):
"""
Get the last text segment.
"""
start, end = self.text_spans[-1]
return self.text[start:end]
def split_query_response_tokens(self):
"""
Split the tokens into query and response tokens.
"""
split_index = self.token_spans[0][1]
query = self.tokens[:split_index]
response = self.tokens[split_index:]
mask = self.token_masks[split_index:]
return query, response, mask
def show_text(self, show_legend=False):
"""
Print the text history.
"""
if not is_rich_available():
warnings.warn("install rich to display text")
return
text = Text(self.text)
text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0])
for i, (start, end) in enumerate(self.text_spans[1:]):
if self.system_spans[i + 1]:
text.stylize(self.system_color, start, end)
else:
text.stylize(self.model_color, start, end)
text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
print(text)
if show_legend:
self.show_colour_legend()
def show_tokens(self, tokenizer, show_legend=False):
"""
Print the history tokens.
"""
if not is_rich_available():
warnings.warn("install rich to display tokens")
return
text = Text()
prompt_end = self.token_spans[0][1]
for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)):
if i < prompt_end:
text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.prompt_color)
text.append(" ")
elif mask == 0:
text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.system_color)
text.append(" ")
else:
text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.model_color)
text.append(" ")
text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
print(text)
if show_legend:
self.show_colour_legend()
def show_colour_legend(self):
"""
Print the colour legend.
"""
if not is_rich_available():
warnings.warn("install rich to display colour legend")
return
text = Text("\n\n(Colour Legend: ")
text.append("Prompt", style=self.prompt_color)
text.append("|")
text.append("System", style=self.system_color)
text.append("|")
text.append("Model", style=self.model_color)
text.append("|")
text.append("Reward", style=self.reward_color)
text.append(")")
print(text)
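# Editorial usage sketch (not part of this diff): how a TextHistory tracks segments
# and token masks -- system segments get mask 0, model segments get mask 1. The
# tokenizer choice below is purely illustrative.
from transformers import AutoTokenizer

example_tokenizer = AutoTokenizer.from_pretrained("gpt2")
example_prompt = "I am a prompt!\n"
example_history = TextHistory(
    example_prompt, example_tokenizer(example_prompt, return_tensors="pt").input_ids[0], system=True
)
example_reply = "<request><Calculator>1 + 1<call>"
example_history.append_segment(
    example_reply, example_tokenizer(example_reply, return_tensors="pt").input_ids[0], system=False
)
print(example_history.last_text_segment)     # "<request><Calculator>1 + 1<call>"
print(example_history.token_masks.tolist())  # zeros for the prompt tokens, ones for the reply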
class TextEnvironment:
"""
The TextEnvironment enables interaction of an LLM with an environment using tools.
"""
def __init__(
self,
model=None,
tokenizer=None,
tools=None,
reward_fn=None,
prompt=None,
max_turns=4,
max_tool_response=100,
max_length=None,
generation_kwargs=None,
):
"""
Initialize TextEnvironment.
Args:
model (`PreTrainedModelWrapper`): The model to use for generation.
tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation.
tools (list or dict): A list of tools (or a dict mapping tool names to tools) to use for interaction.
reward_fn (function): A function that takes a string and returns a reward.
prompt (str): The base prompt to use for generation. Is prepended to the tasks.
max_turns (Optional[int]): The maximum number of turns to allow.
max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response.
max_length (Optional[int]): The maximum number of tokens to allow in an episode.
generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method.
"""
self.model = model
self.tokenizer = tokenizer
self.prompt = prompt
if isinstance(tools, dict):
self.tools = tools
else:
self.tools = dict([(tool.__class__.__name__, tool) for tool in tools])
self.reward_fn = reward_fn
self.max_length = max_length
self.request_token = "<request>"
self.call_token = "<call>"
self.response_token = "<response>"
self.submit_token = "<submit>"
self.max_turns = max_turns
self.max_tool_response = max_tool_response
if generation_kwargs is None:
self.generation_kwargs = dict()
else:
self.generation_kwargs = generation_kwargs
self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
self.current_device = extract_model_from_parallel(self.model).pretrained_model.device
def run(self, queries, **rewards_kwargs):
"""
Run the environment on a list of queries.
Args:
queries (list[str]): A list of queries to run the model in the environment on.
"""
turns = 0
queries = [self.prompt + task for task in queries]
queries_tokens = [
self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device)
for query in queries
]
histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)]
while any([not history.completed for history in histories]) and turns < self.max_turns:
histories = self.generate(histories)
histories = self.tasks_end_check(histories)
# TODO: make this parallel rather than for-loop
for i in range(len(histories)):
histories[i] = self.step(histories[i])
histories = self.tasks_end_check(histories, model_turn=False)
turns += 1
self.compute_reward(histories, **rewards_kwargs)
# convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively
queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories]))
rewards = [history.reward for history in histories]
return queries, responses, masks, rewards, histories
def step(self, history):
"""
Step the environment forward one turn.
Args:
history (`TextHistory`): The history to step forward.
"""
truncated, ended = self.task_end_check(history)
if ended:
history.complete(truncated=truncated)
if history.completed:
return history
tool, query = self.parse_tool_call(history.last_text_segment)
if tool is None or query is None:
response = f"Unknown tool call: {history.last_text_segment}"
else:
if tool not in self.tools:
response = f"Unknown tool {tool}."
else:
try:
response = self.tools[tool](query)
except Exception as error:
response = f"Tool error: {str(error)}"
if len(response) > self.max_tool_response:
response = response[: (self.max_tool_response - 3)] + "..."
history.append_segment(
response + self.response_token,
self.tokenizer(response + self.response_token, return_tensors="pt")
.input_ids[0]
.to(self.model.pretrained_model.device),
system=True,
)
return history
def parse_tool_call(self, text):
"""
Parse request string. Expected format: <request><tool_name>query<call>
"""
result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL)
# if we can't find a <request>/<call> span we return none
if result is None:
return None, None
else:
extracted_text = result.group()
result = re.search(r"<(.*?)>", extracted_text)
# if we can't find a tool name we return none
if result is None:
return None, None
else:
tool = result.group(1)
# split off the tool name
query = ">".join(extracted_text.split(">")[1:])
return tool, query
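# Editorial note (not part of this diff), grounded in the tests above: for the string
# "Something something <request><Tool1>Hello there!<call>" this method returns
# ("Tool1", "Hello there!"); strings missing the <request>/<call> span or the
# <tool_name> tag return (None, None).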
def compute_reward(self, histories, **reward_kwargs):
"""
Compute the reward for a list of histories.
"""
rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs)
for history, reward in zip(histories, rewards):
history.reward = reward
return histories
def generate(self, histories):
"""
Generate responses for a list of histories.
"""
active_histories = [i for i, history in enumerate(histories) if not history.completed]
query_tensors = [histories[i].tokens for i in active_histories]
response_tensors = self._generate_batched(query_tensors)
response_texts = self.tokenizer.batch_decode(response_tensors)
for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors):
histories[i].append_segment(response_text, response_tensor, system=False)
return histories
def tasks_end_check(self, histories, model_turn=True):
"""
Check if the current generation sequences have finished.
"""
for history in histories:
if not history.completed:
truncated, ended = self.task_end_check(history, model_turn=model_turn)
if ended:
history.complete(truncated=truncated)
return histories
def task_end_check(self, history, model_turn=True):
"""
Check if the current generation sequence has finished.
"""
truncated = False
ended = False
if history.completed:
return truncated, ended
if self.max_length is not None and len(self.tokenizer(history.text).input_ids[0]) > self.max_length:
truncated = True
ended = True
elif self.tokenizer.eos_token in history.text:
ended = True
elif model_turn and not (
(self.request_token in history.last_text_segment and self.call_token in history.last_text_segment)
or self.submit_token in history.last_text_segment
):
ended = True
elif self.submit_token in history.last_text_segment:
ended = True
return truncated, ended
def _generate_batched(
self,
query_tensors,
batch_size: int = 16,
pad_to_multiple_of: int = None,
):
"""
Generate responses for a list of query tensors.
args:
query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for.
batch_size (int): The batch size to use for generation.
pad_to_multiple_of (int): The padding length to use for generation.
"""
outputs = []
padding_side_default = self.tokenizer.padding_side
if not self.is_encoder_decoder:
self.tokenizer.padding_side = "left"
# in case we have fewer examples than bs
batch_size = min(len(query_tensors), batch_size)
for i in range(0, len(query_tensors), batch_size):
# prevent overflow if query tensors are not even multiple of bs
end_index = min(len(query_tensors), i + batch_size)
batch = query_tensors[i:end_index]
batch_mask = [torch.ones_like(element) for element in batch]
inputs = {"input_ids": batch, "attention_mask": batch_mask}
padded_inputs = self.tokenizer.pad(
inputs,
padding=True,
max_length=None,
pad_to_multiple_of=pad_to_multiple_of,
return_tensors="pt",
).to(self.current_device)
stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer)
self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria])
generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs)
for generation, mask, generated_tokens in zip(
generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens
):
if not self.is_encoder_decoder:
output = generation[(1 - mask).sum() :] # remove padding
else:
output = generation
if not self.is_encoder_decoder:
output = output[(mask).sum() :] # remove prompt
# remove chunk generated after stopping criteria in batch mode
outputs.append(output[:generated_tokens])
self.tokenizer.padding_side = padding_side_default
return outputs
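# Editorial end-to-end sketch (not part of this diff): running a TextEnvironment with a
# toy calculator tool and a placeholder exact-match reward. The model, tool, prompt and
# generation settings are all illustrative.
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, TextEnvironment

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def exact_match_reward(responses, answers=None):
    # reward 1.0 if the expected answer appears in the model's last segment
    return [torch.tensor(float(answer in response)) for response, answer in zip(responses, answers)]

env = TextEnvironment(
    model,
    tokenizer,
    tools={"SimpleCalculatorTool": lambda expr: str(eval(expr))},  # toy tool, illustration only
    reward_fn=exact_match_reward,
    prompt="Use <request><SimpleCalculatorTool>expression<call> to compute, then answer.\n",
    max_turns=2,
    generation_kwargs={"do_sample": False, "max_new_tokens": 32, "pad_token_id": tokenizer.eos_token_id},
)
queries, responses, masks, rewards, histories = env.run(["What is 13 + 4?"], answers=["17"])
histories[0].show_text()  # pretty-prints the interaction if `rich` is installed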

View File

@ -35,3 +35,19 @@ def is_torch_greater_2_0():
torch_version = pkg_resources.get_distribution("torch").version
return torch_version >= "2.0"
def is_diffusers_available():
return importlib.util.find_spec("diffusers") is not None
def is_bitsandbytes_available():
return importlib.util.find_spec("bitsandbytes") is not None
def is_torchvision_available():
return importlib.util.find_spec("torchvision") is not None
def is_rich_available():
return importlib.util.find_spec("rich") is not None

View File

@ -21,3 +21,14 @@ SUPPORTED_ARCHITECTURES = (
AutoModelForCausalLMWithValueHead,
AutoModelForSeq2SeqLMWithValueHead,
)
from ..import_utils import is_diffusers_available
if is_diffusers_available():
from .modeling_sd_base import (
DDPOPipelineOutput,
DDPOSchedulerOutput,
DDPOStableDiffusionPipeline,
DefaultDDPOStableDiffusionPipeline,
)

View File

@ -20,6 +20,7 @@ import torch
import torch.nn as nn
from accelerate import Accelerator
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, HFValidationError, LocalEntryNotFoundError
from transformers import PreTrainedModel
from ..import_utils import is_peft_available
@ -115,12 +116,14 @@ class PreTrainedModelWrapper(nn.Module):
reward_adapter = kwargs.pop("reward_adapter", None)
is_trainable = kwargs.pop("is_trainable", False)
trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs)
token = pretrained_kwargs.get("token", None)
else:
peft_config = None
is_trainable = False
trl_model_args = {}
pretrained_kwargs = {}
peft_quantization_kwargs = {}
token = None
if reward_adapter is not None and not isinstance(reward_adapter, str):
raise ValueError(
@ -156,8 +159,12 @@ class PreTrainedModelWrapper(nn.Module):
if is_peft_available():
try:
# If there is a trained peft adapter in the hub, load its config.
remote_adapter_config = hf_hub_download(pretrained_model_name_or_path, "adapter_config.json")
except: # noqa
remote_adapter_config = hf_hub_download(
pretrained_model_name_or_path,
"adapter_config.json",
token=token,
)
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError):
remote_adapter_config = None
else:
remote_adapter_config = None
@ -175,7 +182,8 @@ class PreTrainedModelWrapper(nn.Module):
if local_adapter_present:
trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path)
else:
trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_config)
remote_adapter_dir = os.path.dirname(remote_adapter_config)
trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir)
# Load the pretrained base model
pretrained_model = cls.transformers_parent_class.from_pretrained(
@ -241,17 +249,24 @@ class PreTrainedModelWrapper(nn.Module):
if not os.path.exists(filename):
try:
filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin")
filename = hf_hub_download(
pretrained_model_name_or_path,
"pytorch_model.bin",
token=token,
)
# sharded
except: # noqa
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError):
if os.path.exists(sharded_index_filename):
index_file_name = sharded_index_filename
else:
try:
index_file_name = hf_hub_download(
pretrained_model_name_or_path, "pytorch_model.bin.index.json"
pretrained_model_name_or_path,
"pytorch_model.bin.index.json",
token=token,
)
except ValueError: # not continue training, do not have v_head weight
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError):
# not continue training, do not have v_head weight
is_resuming_training = False
logging.warning(
f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', "
@ -267,12 +282,17 @@ class PreTrainedModelWrapper(nn.Module):
if any([module in k for module in cls.supported_modules]):
files_to_download.add(v)
is_shared = True
if is_resuming_training:
if is_shared:
# download each file and add it to the state_dict
state_dict = {}
for shard_file in files_to_download:
filename = hf_hub_download(pretrained_model_name_or_path, shard_file)
filename = hf_hub_download(
pretrained_model_name_or_path,
shard_file,
token=token,
)
state_dict.update(torch.load(filename, map_location="cpu"))
else:
state_dict = torch.load(filename, map_location="cpu")
@ -290,7 +310,7 @@ class PreTrainedModelWrapper(nn.Module):
if not is_peft_model and reward_adapter is not None:
raise ValueError("reward_adapter can only be used with a PeftModel. ")
elif is_peft_model and reward_adapter is not None:
model.add_and_load_reward_modeling_adapter(reward_adapter)
model.add_and_load_reward_modeling_adapter(reward_adapter, token=token)
model.supports_rm_adapter = True
else:
model.supports_rm_adapter = False
@ -400,7 +420,7 @@ class PreTrainedModelWrapper(nn.Module):
"""
raise NotImplementedError
def add_and_load_reward_modeling_adapter(self, adapter_model_id, adapter_name="reward_model_adapter"):
def add_and_load_reward_modeling_adapter(self, adapter_model_id, adapter_name="reward_model_adapter", token=None):
r"""
Add and load a reward modeling adapter. This method can only be used if the
model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id`
@ -410,7 +430,11 @@ class PreTrainedModelWrapper(nn.Module):
filename = os.path.join(adapter_model_id, "adapter_model.bin")
if not os.path.exists(filename):
try:
local_filename = hf_hub_download(adapter_model_id, "adapter_model.bin")
local_filename = hf_hub_download(
adapter_model_id,
"adapter_model.bin",
token=token,
)
except: # noqa
raise ValueError(
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
@ -441,7 +465,10 @@ class PreTrainedModelWrapper(nn.Module):
num_labels, hidden_dim = score_dict["weight"].shape
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])
self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(self._get_current_device())
self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
device=self._get_current_device(),
dtype=self.pretrained_model.dtype,
)
self.score.load_state_dict(score_dict)
# load the adapter to the model
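# Editorial sketch (not part of this diff): with the `token` forwarding added above, an auth
# token passed to `from_pretrained` now reaches the Hub downloads (adapter config, weights,
# reward adapter). The repo id is a hypothetical private repository, and this assumes a
# transformers version whose `from_pretrained` accepts `token`.
from trl import AutoModelForCausalLMWithValueHead

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "my-org/private-peft-model",  # hypothetical private repo containing a PEFT adapter
    token="hf_xxx",               # placeholder read token with access to the repo
    is_trainable=True,
)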

View File

@ -0,0 +1,644 @@
# Copyright 2023 DDPO-pytorch authors (Kevin Black), The HuggingFace Team, metric-space. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import os
import warnings
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
from diffusers.utils import randn_tensor
@dataclass
class DDPOPipelineOutput(object):
"""
Output class for the diffusers pipeline to be finetuned with the DDPO trainer
Args:
images (`torch.Tensor`):
The generated images.
latents (`List[torch.Tensor]`):
The latents used to generate the images.
log_probs (`List[torch.Tensor]`):
The log probabilities of the latents.
"""
images: torch.Tensor
latents: torch.Tensor
log_probs: torch.Tensor
@dataclass
class DDPOSchedulerOutput(object):
"""
Output class for the diffusers scheduler to be finetuned with the DDPO trainer
Args:
latents (`torch.Tensor`):
Predicted sample at the previous timestep. Shape: `(batch_size, num_channels, height, width)`
log_probs (`torch.Tensor`):
Log probability of the above mentioned sample. Shape: `(batch_size)`
"""
latents: torch.Tensor
log_probs: torch.Tensor
class DDPOStableDiffusionPipeline(object):
"""
Main class for the diffusers pipeline to be finetuned with the DDPO trainer
"""
def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
raise NotImplementedError
def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
raise NotImplementedError
@property
def unet(self):
"""
Returns the 2d U-Net model used for diffusion.
"""
raise NotImplementedError
@property
def vae(self):
"""
Returns the Variational Autoencoder model used for mapping images to and from the latent space
"""
raise NotImplementedError
@property
def tokenizer(self):
"""
Returns the tokenizer used for tokenizing text inputs
"""
raise NotImplementedError
@property
def scheduler(self):
"""
Returns the scheduler associated with the pipeline used for the diffusion process
"""
raise NotImplementedError
@property
def text_encoder(self):
"""
Returns the text encoder used for encoding text inputs
"""
raise NotImplementedError
@property
def autocast(self):
"""
Returns the autocast context manager
"""
raise NotImplementedError
def set_progress_bar_config(self, *args, **kwargs):
"""
Sets the progress bar config for the pipeline
"""
raise NotImplementedError
def save_pretrained(self, *args, **kwargs):
"""
Saves all of the model weights
"""
raise NotImplementedError
def get_trainable_layers(self, *args, **kwargs):
"""
Returns the trainable parameters of the pipeline
"""
raise NotImplementedError
def save_checkpoint(self, *args, **kwargs):
"""
Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state
"""
raise NotImplementedError
def load_checkpoint(self, *args, **kwargs):
"""
Light wrapper around accelerate's register_load_state_pre_hook which is run before loading state
"""
raise NotImplementedError
def _left_broadcast(input_tensor, shape):
"""
As opposed to the default direction of broadcasting (right to left), this function broadcasts
from left to right
Args:
input_tensor (`torch.FloatTensor`): is the tensor to broadcast
shape (`Tuple[int]`): is the shape to broadcast to
"""
input_ndim = input_tensor.ndim
if input_ndim > len(shape):
raise ValueError(
"The number of dimensions of the tensor to broadcast cannot be greater than the length of the shape to broadcast to"
)
return input_tensor.reshape(input_tensor.shape + (1,) * (len(shape) - input_ndim)).broadcast_to(shape)
def _get_variance(self, timestep, prev_timestep):
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
alpha_prod_t_prev = torch.where(
prev_timestep.cpu() >= 0,
self.alphas_cumprod.gather(0, prev_timestep.cpu()),
self.final_alpha_cumprod,
).to(timestep.device)
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
return variance
def scheduler_step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
eta: float = 0.0,
use_clipped_model_output: bool = False,
generator=None,
prev_sample: Optional[torch.FloatTensor] = None,
) -> DDPOSchedulerOutput:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
eta (`float`): weight of noise for added noise in diffusion step.
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
coincide with the one provided as input and `use_clipped_model_output` will have no effect.
generator: random number generator.
variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
can directly provide the noise for the variance itself. This is useful for methods such as
CycleDiffusion. (https://arxiv.org/abs/2210.05559)
Returns:
`DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample
"""
if self.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read the DDIM paper in detail for a full understanding
# Notation (<variable name> -> <name in paper>)
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
# to prevent OOB on gather
prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
alpha_prod_t_prev = torch.where(
prev_timestep.cpu() >= 0,
self.alphas_cumprod.gather(0, prev_timestep.cpu()),
self.final_alpha_cumprod,
)
alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
pred_epsilon = model_output
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction`"
)
# 4. Clip or threshold "predicted x_0"
if self.config.thresholding:
pred_original_sample = self._threshold_sample(pred_original_sample)
elif self.config.clip_sample:
pred_original_sample = pred_original_sample.clamp(
-self.config.clip_sample_range, self.config.clip_sample_range
)
# 5. compute variance: "sigma_t(η)" -> see formula (16)
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
variance = _get_variance(self, timestep, prev_timestep)
std_dev_t = eta * variance ** (0.5)
std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
if use_clipped_model_output:
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
if prev_sample is not None and generator is not None:
raise ValueError(
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
" `prev_sample` stays `None`."
)
if prev_sample is None:
variance_noise = randn_tensor(
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
prev_sample = prev_sample_mean + std_dev_t * variance_noise
# log prob of prev_sample given prev_sample_mean and std_dev_t
log_prob = (
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
- torch.log(std_dev_t)
- torch.log(torch.sqrt(2 * torch.as_tensor(np.pi)))
)
# mean along all but batch dimension
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
return DDPOSchedulerOutput(prev_sample.type(sample.dtype), log_prob)
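# Editorial check (not part of this diff): the log probability above is the elementwise
# log-density of Normal(prev_sample_mean, std_dev_t), averaged over all non-batch
# dimensions, and agrees with torch.distributions.Normal.log_prob.
import numpy as np
import torch

mean = torch.randn(2, 4, 8, 8)
std = torch.full_like(mean, 0.3)
x = mean + std * torch.randn_like(mean)
manual = (
    -((x - mean) ** 2) / (2 * std**2)
    - torch.log(std)
    - torch.log(torch.sqrt(2 * torch.as_tensor(np.pi)))
).mean(dim=tuple(range(1, x.ndim)))
reference = torch.distributions.Normal(mean, std).log_prob(x).mean(dim=tuple(range(1, x.ndim)))
assert torch.allclose(manual, reference, atol=1e-5)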
# 1. The output type for call is different as the logprobs are now returned
# 2. An extra method called `scheduler_step` is added which is used to constraint the scheduler output
@torch.no_grad()
def pipeline_step(
self,
prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
Examples:
Returns:
`DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities
"""
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 5. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
all_latents = [latents]
all_log_probs = []
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta)
latents = scheduler_output.latents
log_prob = scheduler_output.log_probs
all_latents.append(latents)
all_log_probs.append(log_prob)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
return DDPOPipelineOutput(image, all_latents, all_log_probs)
class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline):
def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True):
self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
pretrained_model_name, revision=pretrained_model_revision
)
self.use_lora = use_lora
self.pretrained_model = pretrained_model_name
self.pretrained_revision = pretrained_model_revision
try:
self.sd_pipeline.unet.load_attn_procs(pretrained_model_name, revision=pretrained_model_revision)
self.use_lora = True
except OSError:
if use_lora:
warnings.warn(
"If you are aware that the pretrained model has no lora weights to it, ignore this message. "
"Otherwise please check the if `pytorch_lora_weights.safetensors` exists in the model folder."
)
self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config)
self.sd_pipeline.safety_checker = None
# memory optimization
self.sd_pipeline.vae.requires_grad_(False)
self.sd_pipeline.text_encoder.requires_grad_(False)
self.sd_pipeline.unet.requires_grad_(not self.use_lora)
def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
return pipeline_step(self.sd_pipeline, *args, **kwargs)
def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs)
@property
def unet(self):
return self.sd_pipeline.unet
@property
def vae(self):
return self.sd_pipeline.vae
@property
def tokenizer(self):
return self.sd_pipeline.tokenizer
@property
def scheduler(self):
return self.sd_pipeline.scheduler
@property
def text_encoder(self):
return self.sd_pipeline.text_encoder
@property
def autocast(self):
return contextlib.nullcontext if self.use_lora else None
def save_pretrained(self, output_dir):
if self.use_lora:
self.sd_pipeline.unet.save_attn_procs(output_dir)
self.sd_pipeline.save_pretrained(output_dir)
def set_progress_bar_config(self, *args, **kwargs):
self.sd_pipeline.set_progress_bar_config(*args, **kwargs)
def get_trainable_layers(self):
if self.use_lora:
# Set correct lora layers
lora_attn_procs = {}
for name in self.sd_pipeline.unet.attn_processors.keys():
cross_attention_dim = (
None if name.endswith("attn1.processor") else self.sd_pipeline.unet.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = self.sd_pipeline.unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.sd_pipeline.unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.sd_pipeline.unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
self.sd_pipeline.unet.set_attn_processor(lora_attn_procs)
return AttnProcsLayers(self.sd_pipeline.unet.attn_processors)
else:
return self.sd_pipeline.unet
def save_checkpoint(self, models, weights, output_dir):
if len(models) != 1:
raise ValueError("Given how the trainable params were set, this should be of length 1")
if self.use_lora and isinstance(models[0], AttnProcsLayers):
self.sd_pipeline.unet.save_attn_procs(output_dir)
elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
models[0].save_pretrained(os.path.join(output_dir, "unet"))
else:
raise ValueError(f"Unknown model type {type(models[0])}")
def load_checkpoint(self, models, input_dir):
if len(models) != 1:
raise ValueError("Given how the trainable params were set, this should be of length 1")
if self.use_lora and isinstance(models[0], AttnProcsLayers):
tmp_unet = UNet2DConditionModel.from_pretrained(
self.pretrained_model,
revision=self.pretrained_revision,
subfolder="unet",
)
tmp_unet.load_attn_procs(input_dir)
models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
del tmp_unet
elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
models[0].register_to_config(**load_model.config)
models[0].load_state_dict(load_model.state_dict())
del load_model
else:
raise ValueError(f"Unknown model type {type(models[0])}")

View File

@ -16,11 +16,25 @@
# There is a circular import in the PPOTrainer if we let isort sort these
# isort: off
from .utils import AdaptiveKLController, FixedKLController, ConstantLengthDataset, DataCollatorForCompletionOnlyLM
from .utils import (
AdaptiveKLController,
FixedKLController,
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
RunningMoments,
disable_dropout_in_model,
)
# isort: on
from ..import_utils import is_diffusers_available
from .base import BaseTrainer
from .ddpo_config import DDPOConfig
if is_diffusers_available():
from .ddpo_trainer import DDPOTrainer
from .dpo_trainer import DPOTrainer
from .ppo_config import PPOConfig
from .ppo_trainer import PPOTrainer

trl/trainer/ddpo_config.py Normal file
View File

@ -0,0 +1,140 @@
import warnings
from dataclasses import dataclass, field
from typing import Optional
from ..core import flatten_dict
from ..import_utils import is_bitsandbytes_available, is_torchvision_available
@dataclass
class DDPOConfig(object):
"""
Configuration class for DDPOTrainer
"""
run_name: Optional[str] = field(
default="",
metadata={"help": "Run name for wandb logging and checkpoint saving."},
)
seed: Optional[int] = field(default=42, metadata={"help": "Random seed for reproducibility."})
logdir: Optional[str] = field(
default="logs",
metadata={"help": "Top-level logging directory for checkpoint saving."},
)
log_with: Optional[str] = field(
default=None,
metadata={
"help": "Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"
},
)
tracker_kwargs: Optional[dict] = field(
default_factory=dict,
metadata={"help": "Keyword arguments for the tracker (e.g. wandb_project)"},
)
accelerator_kwargs: Optional[dict] = field(
default_factory=dict,
metadata={"help": "Keyword arguments for the accelerator"},
)
project_kwargs: Optional[dict] = field(
default_factory=dict,
metadata={"help": "Keyword arguments for the accelerator project config (e.g. `logging_dir`)"},
)
tracker_project_name: Optional[str] = field(
default="trl", metadata={"help": "Name of project to use for tracking"}
)
num_epochs: Optional[int] = field(default=100, metadata={"help": "Number of epochs to train."})
save_freq: Optional[int] = field(
default=1,
metadata={"help": "Number of epochs between saving model checkpoints."},
)
num_checkpoint_limit: Optional[int] = field(
default=5,
metadata={"help": "Number of checkpoints to keep before overwriting old ones."},
)
mixed_precision: Optional[str] = field(default="fp16", metadata={"help": "Mixed precision training."})
allow_tf32: Optional[bool] = field(default=True, metadata={"help": "Allow tf32 on Ampere GPUs."})
resume_from: Optional[str] = field(default="", metadata={"help": "Resume training from a checkpoint."})
sample_num_steps: Optional[int] = field(default=50, metadata={"help": "Number of sampler inference steps."})
sample_eta: Optional[float] = field(default=1.0, metadata={"help": "Eta parameter for the DDIM sampler."})
sample_guidance_scale: Optional[float] = field(default=5.0, metadata={"help": "Classifier-free guidance weight."})
sample_batch_size: Optional[int] = field(
default=1, metadata={"help": "Batch size (per GPU!) to use for sampling."}
)
sample_num_batches_per_epoch: Optional[int] = field(
default=2, metadata={"help": "Number of batches to sample per epoch."}
)
train_batch_size: Optional[int] = field(default=1, metadata={"help": "Batch size (per GPU!) to use for training."})
train_use_8bit_adam: Optional[bool] = field(
default=False,
metadata={"help": "Whether to use the 8bit Adam optimizer from bitsandbytes."},
)
train_learning_rate: Optional[float] = field(default=3e-4, metadata={"help": "Learning rate."})
train_adam_beta1: Optional[float] = field(default=0.9, metadata={"help": "Adam beta1."})
train_adam_beta2: Optional[float] = field(default=0.999, metadata={"help": "Adam beta2."})
train_adam_weight_decay: Optional[float] = field(default=1e-4, metadata={"help": "Adam weight decay."})
train_adam_epsilon: Optional[float] = field(default=1e-8, metadata={"help": "Adam epsilon."})
train_gradient_accumulation_steps: Optional[int] = field(
default=1, metadata={"help": "Number of gradient accumulation steps."}
)
train_max_grad_norm: Optional[float] = field(
default=1.0, metadata={"help": "Maximum gradient norm for gradient clipping."}
)
train_num_inner_epochs: Optional[int] = field(
default=1, metadata={"help": "Number of inner epochs per outer epoch."}
)
train_cfg: Optional[bool] = field(
default=True,
metadata={"help": "Whether or not to use classifier-free guidance during training."},
)
train_adv_clip_max: Optional[float] = field(default=5, metadata={"help": "Maximum absolute value to clip advantages to."})
train_clip_range: Optional[float] = field(default=1e-4, metadata={"help": "The PPO clip range."})
train_timestep_fraction: Optional[float] = field(
default=1.0, metadata={"help": "The fraction of timesteps to train on."}
)
per_prompt_stat_tracking: Optional[bool] = field(
default=False,
metadata={"help": "Whether to track statistics for each prompt separately."},
)
per_prompt_stat_tracking_buffer_size: Optional[int] = field(
default=16,
metadata={"help": "Number of reward values to store in the buffer for each prompt."},
)
per_prompt_stat_tracking_min_count: Optional[int] = field(
default=16,
metadata={"help": "The minimum number of reward values to store in the buffer."},
)
async_reward_computation: Optional[bool] = field(
default=False,
metadata={"help": "Whether to compute rewards asynchronously."},
)
max_workers: Optional[int] = field(
default=2,
metadata={"help": "The maximum number of workers to use for async reward computation."},
)
negative_prompts: Optional[str] = field(
default="",
metadata={"help": "Comma-separated list of prompts to use as negative examples."},
)
def to_dict(self):
output_dict = {}
for key, value in self.__dict__.items():
output_dict[key] = value
return flatten_dict(output_dict)
def __post_init__(self):
if self.log_with not in ["wandb", "tensorboard"]:
warnings.warn(
("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'.")
)
if self.log_with == "wandb" and not is_torchvision_available():
warnings.warn("Wandb image logging requires torchvision to be installed")
if self.train_use_8bit_adam and not is_bitsandbytes_available():
raise ImportError(
"You need to install bitsandbytes to use 8bit Adam. "
"You can install it with `pip install bitsandbytes`."
)
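# Editorial sketch (not part of this diff): constructing a DDPOConfig with a few of the
# fields above overridden. Values are illustrative; importing DDPOConfig from `trl`
# assumes diffusers is installed.
from trl import DDPOConfig

ddpo_config = DDPOConfig(
    num_epochs=100,
    sample_num_steps=50,
    sample_batch_size=1,
    train_batch_size=1,
    train_learning_rate=3e-4,
    train_gradient_accumulation_steps=1,
    mixed_precision="fp16",
    log_with="wandb",  # or "tensorboard"
)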

trl/trainer/ddpo_trainer.py Normal file
View File

@ -0,0 +1,576 @@
# Copyright 2023 DDPO-pytorch authors (Kevin Black), metric-space, The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import defaultdict
from concurrent import futures
from typing import Any, Callable, Optional, Tuple
from warnings import warn
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from ..models import DDPOStableDiffusionPipeline
from . import BaseTrainer, DDPOConfig
from .utils import PerPromptStatTracker
logger = get_logger(__name__)
class DDPOTrainer(BaseTrainer):
"""
The DDPOTrainer uses Denoising Diffusion Policy Optimization to optimise diffusion models.
Note, this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
As of now only Stable Diffusion based pipelines are supported
Attributes:
**config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for more
details.
**reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used
**prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model
**sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
"""
def __init__(
self,
config: DDPOConfig,
reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor],
prompt_function: Callable[[], Tuple[str, Any]],
sd_pipeline: DDPOStableDiffusionPipeline,
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
):
if image_samples_hook is None:
warn("No image_samples_hook provided; no images will be logged")
self.prompt_fn = prompt_function
self.reward_fn = reward_function
self.config = config
self.image_samples_callback = image_samples_hook
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
if self.config.resume_from:
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
if "checkpoint_" not in os.path.basename(self.config.resume_from):
# get the most recent checkpoint in this directory
checkpoints = list(
filter(
lambda x: "checkpoint_" in x,
os.listdir(self.config.resume_from),
)
)
if len(checkpoints) == 0:
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
self.config.resume_from = os.path.join(
self.config.resume_from,
f"checkpoint_{checkpoint_numbers[-1]}",
)
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
# number of timesteps within each trajectory to train on
self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
self.accelerator = Accelerator(
log_with=self.config.log_with,
mixed_precision=self.config.mixed_precision,
project_config=accelerator_project_config,
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
# the total number of optimizer steps to accumulate across.
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
**self.config.accelerator_kwargs,
)
is_okay, message = self._config_check()
if not is_okay:
raise ValueError(message)
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
if self.accelerator.is_main_process:
self.accelerator.init_trackers(
self.config.tracker_project_name,
config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
init_kwargs=self.config.tracker_kwargs,
)
logger.info(f"\n{config}")
set_seed(self.config.seed, device_specific=True)
self.sd_pipeline = sd_pipeline
self.sd_pipeline.set_progress_bar_config(
position=1,
disable=not self.accelerator.is_local_main_process,
leave=False,
desc="Timestep",
dynamic_ncols=True,
)
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
if self.accelerator.mixed_precision == "fp16":
inference_dtype = torch.float16
elif self.accelerator.mixed_precision == "bf16":
inference_dtype = torch.bfloat16
else:
inference_dtype = torch.float32
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
trainable_layers = self.sd_pipeline.get_trainable_layers()
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if self.config.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
self.optimizer = self._setup_optimizer(trainable_layers.parameters())
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
self.sd_pipeline.tokenizer(
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.sd_pipeline.tokenizer.model_max_length,
).input_ids.to(self.accelerator.device)
)[0]
if config.per_prompt_stat_tracking:
self.stat_tracker = PerPromptStatTracker(
config.per_prompt_stat_tracking_buffer_size,
config.per_prompt_stat_tracking_min_count,
)
# NOTE: autocast is required for non-LoRA training; for LoRA training it is unnecessary and only increases memory usage
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
if self.config.async_reward_computation:
self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
if config.resume_from:
logger.info(f"Resuming from {config.resume_from}")
self.accelerator.load_state(config.resume_from)
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
else:
self.first_epoch = 0
def compute_rewards(self, prompt_image_pairs, is_async=False):
if not is_async:
rewards = []
for images, prompts, prompt_metadata in prompt_image_pairs:
reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
rewards.append(
(
torch.as_tensor(reward, device=self.accelerator.device),
reward_metadata,
)
)
else:
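# async path: `executor.map` schedules one `reward_fn` call per (images, prompts,
# prompt_metadata) triple; the `.result()` calls below assume that, in async mode,
# the reward function returns a pair of future-like objects rather than plain values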
rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
rewards = [
(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
for reward, reward_metadata in rewards
]
return zip(*rewards)
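# `step` below performs one outer DDPO epoch:
#   1. sample trajectories/images with `_generate_samples`,
#   2. score them via `compute_rewards` and turn rewards into advantages
#      (per-prompt via PerPromptStatTracker, or whitened across the batch),
#   3. shuffle and re-batch the samples, then run `train_num_inner_epochs`
#      passes of clipped-surrogate updates in `_train_batched_samples`.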
def step(self, epoch: int, global_step: int):
"""
Perform a single step of training.
Args:
epoch (int): The current epoch.
global_step (int): The current global step.
Side Effects:
- Model weights are updated
- Logs the statistics to the accelerator trackers.
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
Returns:
global_step (int): The updated global step.
"""
samples, prompt_image_data = self._generate_samples(
iterations=self.config.sample_num_batches_per_epoch,
batch_size=self.config.sample_batch_size,
)
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
rewards, rewards_metadata = self.compute_rewards(
prompt_image_data, is_async=self.config.async_reward_computation
)
for i, image_data in enumerate(prompt_image_data):
image_data.extend([rewards[i], rewards_metadata[i]])
if self.image_samples_callback is not None:
self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
rewards = torch.cat(rewards)
rewards = self.accelerator.gather(rewards).cpu().numpy()
self.accelerator.log(
{
"reward": rewards,
"epoch": epoch,
"reward_mean": rewards.mean(),
"reward_std": rewards.std(),
},
step=global_step,
)
if self.config.per_prompt_stat_tracking:
# gather the prompts across processes
prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
advantages = self.stat_tracker.update(prompts, rewards)
else:
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
# ungather advantages; keep the entries corresponding to the samples on this process
samples["advantages"] = (
torch.as_tensor(advantages)
.reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
.to(self.accelerator.device)
)
del samples["prompt_ids"]
total_batch_size, num_timesteps = samples["timesteps"].shape
for inner_epoch in range(self.config.train_num_inner_epochs):
# shuffle samples along batch dimension
perm = torch.randperm(total_batch_size, device=self.accelerator.device)
samples = {k: v[perm] for k, v in samples.items()}
# shuffle along the time dimension independently for each sample so the order of timestep updates is randomised
perms = torch.stack(
[torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
)
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
samples[key] = samples[key][
torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
perms,
]
original_keys = samples.keys()
original_values = samples.values()
# rebatch them as user defined train_batch_size is different from sample_batch_size
reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
# Transpose the list of original values
transposed_values = zip(*reshaped_values)
# Create new dictionaries for each row of transposed values
samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
self.sd_pipeline.unet.train()
global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
# ensure optimization step at the end of the inner epoch
if not self.accelerator.sync_gradients:
raise ValueError(
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
)
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
self.accelerator.save_state()
return global_step
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
"""
Calculate the loss for a batch of an unpacked sample
Args:
latents (torch.Tensor):
The latents sampled from the diffusion model, shape: [batch_size, num_steps, ...]
timesteps (torch.Tensor):
The timesteps sampled from the diffusion model, shape: [batch_size]
next_latents (torch.Tensor):
The next latents sampled from the diffusion model, shape: [batch_size, num_steps, ...]
log_probs (torch.Tensor):
The log probabilities of the latents, shape: [batch_size]
advantages (torch.Tensor):
The advantages of the latents, shape: [batch_size]
embeds (torch.Tensor):
The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
Note: the batch dimension is `2*batch_size` when `train_cfg` is True because the negative prompt embeddings are concatenated to the prompt embeddings
Returns:
loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
(all of these are of shape (1,))
"""
with self.autocast():
if self.config.train_cfg:
noise_pred = self.sd_pipeline.unet(
torch.cat([latents] * 2),
torch.cat([timesteps] * 2),
embeds,
).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
noise_pred_text - noise_pred_uncond
)
else:
noise_pred = self.sd_pipeline.unet(
latents,
timesteps,
embeds,
).sample
# compute the log prob of next_latents given latents under the current model
scheduler_step_output = self.sd_pipeline.scheduler_step(
noise_pred,
timesteps,
latents,
eta=self.config.sample_eta,
prev_sample=next_latents,
)
log_prob = scheduler_step_output.log_probs
advantages = torch.clamp(
advantages,
-self.config.train_adv_clip_max,
self.config.train_adv_clip_max,
)
ratio = torch.exp(log_prob - log_probs)
loss = self.loss(advantages, self.config.train_clip_range, ratio)
approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
return loss, approx_kl, clipfrac
def loss(
self,
advantages: torch.Tensor,
clip_range: float,
ratio: torch.Tensor,
):
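# PPO-style clipped surrogate objective: take the elementwise maximum of the
# unclipped and clipped losses so that moving the ratio outside
# [1 - clip_range, 1 + clip_range] cannot further decrease the loss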
unclipped_loss = -advantages * ratio
clipped_loss = -advantages * torch.clamp(
ratio,
1.0 - clip_range,
1.0 + clip_range,
)
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
def _setup_optimizer(self, trainable_layers_parameters):
if self.config.train_use_8bit_adam:
import bitsandbytes
optimizer_cls = bitsandbytes.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
return optimizer_cls(
trainable_layers_parameters,
lr=self.config.train_learning_rate,
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
weight_decay=self.config.train_adam_weight_decay,
eps=self.config.train_adam_epsilon,
)
def _save_model_hook(self, models, weights, output_dir):
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
def _load_model_hook(self, models, input_dir):
self.sd_pipeline.load_checkpoint(models, input_dir)
models.pop() # ensures that accelerate doesn't try to handle loading of the model
def _generate_samples(self, iterations, batch_size):
"""
Generate samples from the model
Args:
iterations (int): Number of iterations to generate samples for
batch_size (int): Batch size to use for sampling
Returns:
samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]])
"""
samples = []
prompt_image_pairs = []
self.sd_pipeline.unet.eval()
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
for _ in range(iterations):
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
prompt_ids = self.sd_pipeline.tokenizer(
prompts,
return_tensors="pt",
padding="max_length",
truncation=True,
max_length=self.sd_pipeline.tokenizer.model_max_length,
).input_ids.to(self.accelerator.device)
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
with self.autocast():
sd_output = self.sd_pipeline(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=sample_neg_prompt_embeds,
num_inference_steps=self.config.sample_num_steps,
guidance_scale=self.config.sample_guidance_scale,
eta=self.config.sample_eta,
output_type="pt",
)
images = sd_output.images
latents = sd_output.latents
log_probs = sd_output.log_probs
latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
samples.append(
{
"prompt_ids": prompt_ids,
"prompt_embeds": prompt_embeds,
"timesteps": timesteps,
"latents": latents[:, :-1], # each entry is the latent before timestep t
"next_latents": latents[:, 1:], # each entry is the latent after timestep t
"log_probs": log_probs,
"negative_prompt_embeds": sample_neg_prompt_embeds,
}
)
prompt_image_pairs.append([images, prompts, prompt_metadata])
return samples, prompt_image_pairs
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
"""
Train on a batch of samples. Main training segment
Args:
inner_epoch (int): The current inner epoch
epoch (int): The current epoch
global_step (int): The current global step
batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on
Side Effects:
- Model weights are updated
- Logs the statistics to the accelerator trackers.
Returns:
global_step (int): The updated global step
"""
info = defaultdict(list)
for i, sample in enumerate(batched_samples):
if self.config.train_cfg:
# concat negative prompts to sample prompts to avoid two forward passes
embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
else:
embeds = sample["prompt_embeds"]
for j in range(self.num_train_timesteps):
with self.accelerator.accumulate(self.sd_pipeline.unet):
loss, approx_kl, clipfrac = self.calculate_loss(
sample["latents"][:, j],
sample["timesteps"][:, j],
sample["next_latents"][:, j],
sample["log_probs"][:, j],
sample["advantages"],
embeds,
)
info["approx_kl"].append(approx_kl)
info["clipfrac"].append(clipfrac)
info["loss"].append(loss)
self.accelerator.backward(loss)
if self.accelerator.sync_gradients:
self.accelerator.clip_grad_norm_(
self.trainable_layers.parameters(),
self.config.train_max_grad_norm,
)
self.optimizer.step()
self.optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if self.accelerator.sync_gradients:
# log training-related stuff
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
info = self.accelerator.reduce(info, reduction="mean")
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
self.accelerator.log(info, step=global_step)
global_step += 1
info = defaultdict(list)
return global_step
def _config_check(self) -> Tuple[bool, str]:
samples_per_epoch = (
self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
)
total_train_batch_size = (
self.config.train_batch_size
* self.accelerator.num_processes
* self.config.train_gradient_accumulation_steps
)
if not self.config.sample_batch_size >= self.config.train_batch_size:
return (
False,
f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
)
if not self.config.sample_batch_size % self.config.train_batch_size == 0:
return (
False,
f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
)
if not samples_per_epoch % total_train_batch_size == 0:
return (
False,
f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
)
return True, ""
def train(self, epochs: Optional[int] = None):
"""
Train the model for a given number of epochs
"""
global_step = 0
if epochs is None:
epochs = self.config.num_epochs
for epoch in range(self.first_epoch, epochs):
global_step = self.step(epoch, global_step)
def _save_pretrained(self, save_directory):
self.sd_pipeline.save_pretrained(save_directory)
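To see how the pieces above are meant to be driven end to end, here is a hedged sketch. `DefaultDDPOStableDiffusionPipeline` is assumed to be the concrete `DDPOStableDiffusionPipeline` shipped with trl, and the reward function is a toy stand-in rather than a real scorer.

```python
# Hedged sketch, not part of the diff: wiring a prompt function and a toy reward
# function into DDPOTrainer (imports and pipeline class assumed).
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline

def prompt_fn():
    # must return (prompt, prompt_metadata)
    return "a photo of a corgi", {}

def reward_fn(images, prompts, prompt_metadata):
    # must return (rewards, reward_metadata); here: mean pixel brightness as a toy reward
    return images.float().mean(dim=(1, 2, 3)), {}

pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")
trainer = DDPOTrainer(
    config=DDPOConfig(log_with=None),
    reward_function=reward_fn,
    prompt_function=prompt_fn,
    sd_pipeline=pipeline,
)
trainer.train()  # runs config.num_epochs outer epochs of `step`
```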

trl/trainer/dpo_trainer.py

@ -24,11 +24,12 @@ from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase,
from transformers.trainer_callback import TrainerCallback
from ..import_utils import is_peft_available
from .utils import DPODataCollatorWithPadding, pad_to_length
from ..models import create_reference_model
from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length
if is_peft_available():
from peft import get_peft_model, prepare_model_for_int8_training
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
class DPOTrainer(Trainer):
@ -39,7 +40,8 @@ class DPOTrainer(Trainer):
model (`transformers.PreTrainedModel`):
The model to train, preferably an `AutoModelForCausalLM`.
ref_model (`PreTrainedModelWrapper`):
Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss.
Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
beta (`float`, defaults to 0.1):
The beta factor in DPO loss. Higher beta means less divergence from the initial policy.
args (`transformers.TrainingArguments`):
@ -73,12 +75,14 @@ class DPOTrainer(Trainer):
The maximum length of the prompt. This argument is required if you want to use the default data collator.
peft_config (`Dict`, defaults to `None`):
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
disable_dropout (`bool`, defaults to `True`):
Whether or not to disable dropouts in `model` and `ref_model`.
"""
def __init__(
self,
model: Union[PreTrainedModel, nn.Module] = None,
ref_model: Union[PreTrainedModel, nn.Module] = None,
ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
beta: float = 0.1,
args: TrainingArguments = None,
data_collator: Optional[DataCollator] = None,
@ -98,6 +102,7 @@ class DPOTrainer(Trainer):
max_length: Optional[int] = None,
max_prompt_length: Optional[int] = None,
peft_config: Optional[Dict] = None,
disable_dropout: bool = True,
):
if not is_peft_available() and peft_config is not None:
raise ValueError(
@ -108,6 +113,16 @@ class DPOTrainer(Trainer):
model = prepare_model_for_int8_training(model)
model = get_peft_model(model, peft_config)
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
if ref_model:
self.ref_model = ref_model
elif self.is_peft_model:
# The `model` with adapters turned off will be used as the reference model
self.ref_model = None
else:
self.ref_model = create_reference_model(model)
if data_collator is None:
if tokenizer is None:
raise ValueError(
@ -150,11 +165,15 @@ class DPOTrainer(Trainer):
else:
self.use_dpo_data_collator = False
if disable_dropout:
disable_dropout_in_model(model)
if self.ref_model is not None:
disable_dropout_in_model(self.ref_model)
self.label_pad_token_id = label_pad_token_id
self.padding_value = padding_value
self.beta = beta
self.ref_model = ref_model
self._stored_metrics = defaultdict(lambda: defaultdict(list))
@ -172,14 +191,19 @@ class DPOTrainer(Trainer):
preprocess_logits_for_metrics,
)
# Since we inherit from trainer we always have access to an accelerator
if hasattr(self, "accelerator"):
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
else:
if not hasattr(self, "accelerator"):
raise AttributeError(
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
)
if self.ref_model is None:
if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"):
raise ValueError(
"You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version."
)
else:
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
"""Concatenate the chosen and rejected inputs into a single tensor.
@ -319,12 +343,21 @@ class DPOTrainer(Trainer):
policy_rejected_logits,
) = self.concatenated_forward(model, batch)
with torch.no_grad():
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.ref_model, batch)
if self.ref_model is None:
with self.accelerator.unwrap_model(self.model).disable_adapter():
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.model, batch)
else:
(
reference_chosen_logps,
reference_rejected_logps,
_,
_,
) = self.concatenated_forward(self.ref_model, batch)
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
policy_chosen_logps,
@ -378,13 +411,23 @@ class DPOTrainer(Trainer):
pad_token_id=self.tokenizer.pad_token_id,
)
reference_output = self.ref_model.generate(
batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.config.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
if self.ref_model is None:
with self.accelerator.unwrap_model(self.model).disable_adapter():
reference_output = self.model.generate(
batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.config.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
else:
reference_output = self.ref_model.generate(
batch["prompt_input_ids"],
attention_mask=batch["prompt_attention_mask"],
max_length=self.config.max_length,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id,
)
policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id)
policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
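The changes above make `ref_model` optional and add `disable_dropout`. A hedged sketch of the new construction path with a PEFT adapter (model, dataset, and hyperparameters are illustrative placeholders):

```python
# Sketch: with ref_model=None the trainer either builds a frozen copy via
# create_reference_model or, when a peft_config is given, uses the policy with
# adapters disabled as the implicit reference model.
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# toy preference dataset in the expected "prompt"/"chosen"/"rejected" format
train_dataset = Dataset.from_dict(
    {
        "prompt": ["What is the capital of France?"],
        "chosen": [" Paris."],
        "rejected": [" London."],
    }
)

trainer = DPOTrainer(
    model,
    ref_model=None,                 # reference model handled internally after this change
    beta=0.1,
    args=TrainingArguments(output_dir="dpo-out", per_device_train_batch_size=1, remove_unused_columns=False),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    max_length=256,
    max_prompt_length=128,
    peft_config=LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"),  # illustrative adapter config
    disable_dropout=True,           # new flag from this diff
)
trainer.train()
```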

trl/trainer/ppo_config.py

@ -21,6 +21,8 @@ from typing import Optional
import numpy as np
import requests
from trl.trainer.utils import exact_div
from ..core import flatten_dict
@ -162,6 +164,11 @@ class PPOConfig(object):
ratio_threshold: Optional[float] = field(
default=10.0, metadata={"help": "Skip mini-batches with high PPO ratios that can cause loss spikes"}
)
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
use_score_norm: Optional[bool] = field(
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
)
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})
def __post_init__(self):
if self.forward_batch_size is not None:
@ -170,6 +177,15 @@ class PPOConfig(object):
)
self.mini_batch_size = self.forward_batch_size
self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps
exact_div(
self.batch_size,
self.backward_batch_size,
"`batch_size`",
"`mini_batch_size * gradient_accumulation_steps`",
"`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`",
)
# check if wandb is installed
if self.log_with == "wandb":
# raise error if wandb is not installed
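For reference, a hedged sketch of a `PPOConfig` that exercises the new fields and satisfies the new `exact_div` check (other values are arbitrary):

```python
# Sketch: batch_size must now be an exact multiple of
# mini_batch_size * gradient_accumulation_steps, and the score_* fields control
# reward scaling/normalisation/clipping applied inside PPOTrainer.step.
from trl import PPOConfig

ppo_config = PPOConfig(
    model_name="gpt2",
    batch_size=128,
    mini_batch_size=16,
    gradient_accumulation_steps=4,   # 16 * 4 = 64 divides 128, so exact_div passes
    use_score_scaling=True,          # divide rewards by their running std
    use_score_norm=True,             # additionally subtract the running mean
    score_clip=5.0,                  # clip the (scaled) rewards to [-5, 5]
)
```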

trl/trainer/ppo_trainer.py

@ -53,7 +53,7 @@ from ..core import (
)
from ..import_utils import is_torch_greater_2_0
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
MODEL_CARD_TEMPLATE = """---
@ -66,7 +66,7 @@ tags:
# {model_name}
This is a [TRL language model](https://github.com/lvwerra/trl) that has been fine-tuned with reinforcement learning to
This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to
guide the model outputs according to a value, function, or human feedback. The model can be used for text generation.
## Usage
@ -207,6 +207,7 @@ class PPOTrainer(BaseTrainer):
self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
self.is_peft_model = getattr(self.model, "is_peft_model", False)
self.is_using_text_environment = getattr(config, "use_text_environment", False)
if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
self.ref_model = ref_model
@ -344,6 +345,8 @@ class PPOTrainer(BaseTrainer):
PPODecorators.optimize_cuda_cache = self.config.optimize_cuda_cache
self.running = RunningMoments(self.accelerator)
def _filter_kwargs(self, kwargs, target_func):
"""
filter the keyword arguments that are supported by the target function.
@ -388,7 +391,7 @@ class PPOTrainer(BaseTrainer):
signature = inspect.signature(self.model.forward)
self._signature_columns = list(signature.parameters.keys())
# label => sentiment | we need query and response for logging purpose
self._signature_columns += list(set(["label", "query", "response"]))
self._signature_columns += ["label", "query", "response"]
# Adapted from transformers.Trainer._remove_unused_columns
def _remove_unused_columns(self, dataset: "Dataset"):
@ -524,6 +527,7 @@ class PPOTrainer(BaseTrainer):
queries: List[torch.LongTensor],
responses: List[torch.LongTensor],
scores: List[torch.FloatTensor],
masks: Optional[List[torch.LongTensor]] = None,
):
"""
Check if the input data is valid for training.
@ -537,6 +541,8 @@ class PPOTrainer(BaseTrainer):
List of tensors containing the encoded responses of shape (`response_length`)
scores (List[`torch.FloatTensor`]):
List of tensors containing the scores.
masks (List[`torch.LongTensor`], *optional*):
list of optional tensors containing the masks of shape (`query_length` + `response_length`)
Returns:
`tuple`: The input processed data.
"""
@ -554,6 +560,7 @@ class PPOTrainer(BaseTrainer):
queries = [tensor.to(self.current_device) for tensor in queries]
responses = [tensor.to(self.current_device) for tensor in responses]
scores = [tensor.to(self.current_device) for tensor in scores]
masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None
# squeeze scores if needed
for i, score in enumerate(scores):
@ -562,7 +569,7 @@ class PPOTrainer(BaseTrainer):
elif score.dim() == 1:
scores[i] = score.squeeze()
return queries, responses, scores
return queries, responses, scores, masks
@PPODecorators.empty_cuda_cache()
def step(
@ -570,6 +577,7 @@ class PPOTrainer(BaseTrainer):
queries: List[torch.LongTensor],
responses: List[torch.LongTensor],
scores: List[torch.FloatTensor],
response_masks: Optional[List[torch.LongTensor]] = None,
):
"""
Run a PPO optimisation step given a list of queries, model responses, and rewards.
@ -581,18 +589,35 @@ class PPOTrainer(BaseTrainer):
List of tensors containing the encoded responses of shape (`response_length`)
scores (List[`torch.FloatTensor`]):
List of tensors containing the scores.
response_masks (List[`torch.LongTensor`], *optional*):
List of tensors containing masks of the response tokens.
Returns:
`dict[str, Any]`: A summary of the training statistics
"""
bs = self.config.batch_size
queries, responses, scores = self._step_safety_checker(bs, queries, responses, scores)
queries, responses, scores, response_masks = self._step_safety_checker(
bs, queries, responses, scores, response_masks
)
scores = torch.tensor(scores)
if self.config.use_score_scaling:
# Score scaling
scores_mean, scores_std = self.running.update(scores)
if self.config.use_score_norm:
scores = (scores - self.running.mean) / self.running.std
else:
scores /= self.running.std
if self.config.score_clip is not None:
# Score clipping
scores_dtype = scores.dtype
scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
# if we want to push best model to the hub
if hasattr(self, "highest_reward"):
if self.compare_step % self.config.compare_steps == 0:
curr_mean_reward = torch.tensor(scores).mean()
curr_mean_reward = scores.mean()
# if the best reward ever seen
if curr_mean_reward > self.highest_reward:
self.highest_reward = curr_mean_reward
@ -639,9 +664,13 @@ class PPOTrainer(BaseTrainer):
with torch.no_grad():
all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
self.model, queries, responses, model_inputs, return_logits=full_kl_penalty
self.model,
queries,
responses,
model_inputs,
response_masks=response_masks,
return_logits=full_kl_penalty,
)
# for when the model is a peft model
if self.is_peft_model and hasattr(
self.accelerator.unwrap_model(self.model).pretrained_model,
@ -864,7 +893,6 @@ class PPOTrainer(BaseTrainer):
input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]
else:
input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
input_data = self.data_collator(
@ -872,7 +900,6 @@ class PPOTrainer(BaseTrainer):
).to(self.current_device)
input_data.pop("labels", None) # we don't want to compute LM losses
return input_data
@PPODecorators.empty_cuda_cache()
@ -883,6 +910,7 @@ class PPOTrainer(BaseTrainer):
responses: torch.Tensor,
model_inputs: dict,
return_logits: bool = False,
response_masks: Optional[torch.Tensor] = None,
):
"""
Calculate model outputs in multiple batches.
@ -913,6 +941,8 @@ class PPOTrainer(BaseTrainer):
input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
query_batch = queries[i * fbs : (i + 1) * fbs]
response_batch = responses[i * fbs : (i + 1) * fbs]
if response_masks is not None:
response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
logits, _, values = model(**input_kwargs)
if self.is_encoder_decoder:
@ -936,9 +966,15 @@ class PPOTrainer(BaseTrainer):
if attention_mask[j, 0] == 0: # offset left padding
start += attention_mask[j, :].nonzero()[0]
end = start + len(response_batch[j])
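# the user-provided response mask only covers response tokens; pad it with zeros
# over the query and drop the first position so it lines up with the one-token
# shift applied to `masks` and the logprobs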
if response_masks is not None:
response_masks_batch[j] = torch.cat(
(torch.zeros_like(query_batch[j]), response_masks_batch[j])
)[1:]
masks[j, :start] = 0
masks[j, end:] = 0
if response_masks is not None:
masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]
if return_logits:
all_logits.append(logits)
@ -1186,8 +1222,8 @@ class PPOTrainer(BaseTrainer):
mean_non_score_reward = masked_mean(
data["non_score_reward"], mask
) # non_score_reward is size `batch_size`, `response_length`
mean_scores = torch.stack(data["scores"]).mean() # scores is size `batch_size`
std_scores = torch.stack(data["scores"]).std()
mean_scores = data["scores"].mean() # scores is size `batch_size`
std_scores = data["scores"].std()
if mean_kl.item() < -1.0:
# warn users
@ -1259,8 +1295,12 @@ class PPOTrainer(BaseTrainer):
elif self.config.log_with == "wandb":
import wandb
table_rows = [list(r) for r in zip(batch["query"], batch["response"], rewards.cpu().tolist())]
logs.update({"game_log": wandb.Table(columns=["query", "response", "reward"], rows=table_rows)})
table_rows = [
list(r) for r in zip(batch["query"], batch["response"], batch["answer"], rewards.cpu().tolist())
]
logs.update(
{"game_log": wandb.Table(columns=["query", "response", "answer", "reward"], rows=table_rows)}
)
# All reduce rewards if distributed
if self.is_distributed:
import torch.distributed as dist
@ -1281,10 +1321,6 @@ class PPOTrainer(BaseTrainer):
logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
logs["env/reward_dist"] = rewards.cpu().numpy()
logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
logs["env/reward_dist"] = rewards.cpu().numpy()
if self.config.log_with == "tensorboard":
# update the current step
self.current_step += 1
@ -1329,3 +1365,18 @@ class PPOTrainer(BaseTrainer):
self.accelerator.unwrap_model(self.model).save_pretrained(save_directory)
self.tokenizer.save_pretrained(save_directory)
self.create_model_card(save_directory)
def _show_tokens(self, tokens, masks):
from rich import print
from rich.text import Text
text = Text()
for i, (token, mask) in enumerate(zip(tokens, masks)):
if mask == 1:
text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1")
text.append(" ")
else:
text.append(self.tokenizer.decode(token.item()), style="black on cyan3")
text.append(" ")
print(text)
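A hedged, toy-sized sketch of the new `response_masks` argument to `step` (model, prompts, and scores are placeholders; in practice the masks come from `TextEnvironment` to exclude tool output from the loss):

```python
# Sketch: masking part of the response out of the PPO loss via response_masks.
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
ppo_trainer = PPOTrainer(PPOConfig(batch_size=1, mini_batch_size=1), model, ref_model=None, tokenizer=tokenizer)

query = tokenizer("What is 2 + 2?", return_tensors="pt").input_ids[0]
response = tokenizer(" The tool says: 4", return_tensors="pt").input_ids[0]
response_mask = torch.ones_like(response)
response_mask[:3] = 0  # e.g. exclude tokens that were produced by a tool call

stats = ppo_trainer.step([query], [response], [torch.tensor(1.0)], response_masks=[response_mask])
```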

trl/trainer/reward_trainer.py

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
@ -129,7 +130,10 @@ class RewardTrainer(Trainer):
data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)
if args.remove_unused_columns:
args.remove_unused_columns = False
try: # for bc before https://github.com/huggingface/transformers/pull/25435
args.remove_unused_columns = False
except FrozenInstanceError:
args = replace(args, remove_unused_columns=False)
# warn users
warnings.warn(
"When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"

trl/trainer/sft_trainer.py

@ -85,7 +85,7 @@ class SFTTrainer(Trainer):
The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`.
chars_per_token (`Optional[float]`):
The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the
stack-llama example: https://github.com/lvwerra/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53.
stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53.
packing (`Optional[bool]`):
Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences
of the dataset.

trl/trainer/utils.py

@ -14,8 +14,9 @@
import os
import random
import warnings
from collections import deque
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
@ -61,8 +62,9 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
Args:
instruction_template (`Optional[str]`): the template form that indicates the start of the human instruction, typically something like
'### Human:\n'. Useful for assistant-style conversation datasets
response_template (`str`): the template form that indicates the start of the response, typically something like
'### Response:\n'
response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like
'### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
differently if it does not have proper context.
mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying
`DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
for flexibility and backwards-compatibility.
@ -72,7 +74,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
def __init__(
self,
response_template: str,
response_template: Union[str, List[int]],
instruction_template: Optional[str] = None,
*args,
mlm: bool = False,
@ -83,7 +85,11 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
self.instruction_template = instruction_template
self.response_template = response_template
self.ignore_index = ignore_index
self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
if type(response_template) == list:
# The user already provides the token ids
self.response_token_ids = response_template
else:
self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
batch = super().torch_call(examples)
@ -101,14 +107,18 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
response_token_ids_start_idx = idx
if response_token_ids_start_idx is None:
raise RuntimeError(
f'Could not find response key {self.response_token_ids} in token IDs {batch["labels"][i]}'
warnings.warn(
f"Could not find response key `{self.response_template}` in the "
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
batch["labels"][i, :] = self.ignore_index
else:
response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)
response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)
# Make pytorch loss function ignore all tokens up through the end of the response key
batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
# Make pytorch loss function ignore all tokens up through the end of the response key
batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
else:
for i in range(len(examples)):
@ -123,10 +133,14 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
):
response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))
if len(self.response_token_ids) == 0:
raise RuntimeError(
f'Could not find response key {self.response_token_ids} in token IDs {batch["labels"][i]}'
if len(response_token_ids_idxs) == 0:
warnings.warn(
f"Could not find response key `{self.response_template}` in the "
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
batch["labels"][i, :] = self.ignore_index
human_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
@ -135,9 +149,13 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
human_token_ids_idxs.append(human_idx)
if len(human_token_ids_idxs) == 0:
raise RuntimeError(
f'Could not find response key {human_token_ids} in token IDs {batch["labels"][i]}'
warnings.warn(
f"Could not find instruction key `{self.instruction_template}` in the "
f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
f"This instance will be ignored in loss calculation. "
f"Note, if this happens often, consider increasing the `max_seq_length`."
)
batch["labels"][i, :] = self.ignore_index
for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
# Make pytorch loss function ignore all non response tokens
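The change above lets `response_template` be passed as pre-tokenized ids, which helps when a tokenizer encodes the template differently depending on surrounding context. A hedged sketch (top-level import path assumed):

```python
# Sketch: constructing the collator from token ids instead of a raw string.
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

response_template = "### Response:\n"
response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)

collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
# every label up to the end of the response template is set to ignore_index, so the
# loss is computed only on the completion tokens
```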
@ -243,7 +261,7 @@ class DPODataCollatorWithPadding:
padding_value (`int`, defaults to 0):
The value used for padding.
truncation_mode: (`str`, defaults to "keep_end"):
The truncation mode to use when truncating the prompt + chosen/rejected responses.
The truncation mode to use when truncating the prompt.
"""
tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str] = True
@ -499,6 +517,65 @@ class PeftSavingCallback(TrainerCallback):
os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))
class RunningMoments:
def __init__(self, accelerator):
"""
Calculates the running mean and standard deviation of a data stream. Reference:
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75
"""
self.mean = 0
self.std = 1
self.var = 1
self.count = 1e-24
self.accelerator = accelerator
@torch.no_grad()
def update(self, xs: torch.Tensor) -> Tuple[float, float]:
"""
Updates running moments from batch's moments computed across ranks
"""
if self.accelerator.use_distributed:
xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs)
else:
xs_count = xs.numel()
xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
xs_mean, xs_var = xs_mean.float(), xs_var.float()
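# merge the batch moments into the running moments with the standard parallel
# variance update: shift the old sum of squared deviations by
# delta**2 * count_old * count_new / count_total before adding the new batch's sum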
delta = xs_mean - self.mean
tot_count = self.count + xs_count
new_sum = xs_var * xs_count
# correct old_sum deviation accounting for the new mean
old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
tot_sum = old_sum + new_sum
self.mean += delta * xs_count / tot_count
self.var = tot_sum / tot_count
self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt()
self.count = tot_count
return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item()
@torch.no_grad()
def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]:
"""
Computes element-wise mean and variance of the tensor across processes. Reference:
https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75
"""
xs = xs.to(accelerator.device)
sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device)
sum_and_count = accelerator.reduce(sum_and_count)
global_sum, count = sum_and_count
global_mean = global_sum / count
sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask))
sum_var = accelerator.reduce(sum_var)
global_var = sum_var / count
return global_mean.to(device), global_var.to(device), count.to(device)
def compute_accuracy(eval_pred) -> Dict[str, float]:
predictions, labels = eval_pred
# Here, predictions is rewards_chosen and rewards_rejected.
@ -522,3 +599,58 @@ def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float
],
dim=dim,
)
def disable_dropout_in_model(model: torch.nn.Module) -> None:
for module in model.modules():
if isinstance(module, torch.nn.Dropout):
module.p = 0
def exact_div(a, b, a_str, b_str, custom_error_message=""):
q = a // b
if a != q * b:
raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}")
return q
# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5
class PerPromptStatTracker:
r"""
Class for tracking statistics per prompt. Mainly used to calculate advantages for the DDPO algorithm
Args:
buffer_size (`int`):
Size of the buffer to keep for each prompt.
min_count (`int`):
Minimum number of samples to keep in the buffer before calculating the mean and std.
"""
def __init__(self, buffer_size, min_count):
self.buffer_size = buffer_size
self.min_count = min_count
self.stats = {}
def update(self, prompts, rewards):
prompts = np.array(prompts)
rewards = np.array(rewards)
unique = np.unique(prompts)
advantages = np.empty_like(rewards)
for prompt in unique:
prompt_rewards = rewards[prompts == prompt]
if prompt not in self.stats:
self.stats[prompt] = deque(maxlen=self.buffer_size)
self.stats[prompt].extend(prompt_rewards)
if len(self.stats[prompt]) < self.min_count:
mean = np.mean(rewards)
std = np.std(rewards) + 1e-6
else:
mean = np.mean(self.stats[prompt])
std = np.std(self.stats[prompt]) + 1e-6
advantages[prompts == prompt] = (prompt_rewards - mean) / std
return advantages
def get_stats(self):
return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()}
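Finally, a hedged sketch of the tracker in isolation (import path assumed from this module):

```python
# Sketch: rewards are standardised against the running buffer of their own prompt,
# falling back to batch-wide statistics until min_count samples have been seen.
import numpy as np
from trl.trainer.utils import PerPromptStatTracker  # import path assumed

tracker = PerPromptStatTracker(buffer_size=32, min_count=2)
prompts = ["a corgi", "a corgi", "a cat", "a cat"]
rewards = np.array([1.0, 3.0, 0.0, 2.0])
advantages = tracker.update(prompts, rewards)  # shape (4,), standardised per prompt
```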