mirror of https://github.com/huggingface/trl.git
synced 2025-10-21 02:53:59 +08:00

Compare commits

31 Commits

5d8ad5d538
9d09b3e107
336d63eb80
7fc970983c
d3bbee3ab8
eb5465df7e
1c272240ac
b095245830
c115453fba
16f214c58d
e9a437992e
c837fbe5b9
01c4a35928
1aca98fbcf
029f961b7c
8ec912ffa6
f360c37466
217313014b
b946e875b1
6dd50b45d8
98120d6aeb
3b2c820db6
25fd6f2313
3f1477cdc0
2cff1e4385
d7d7902938
77b0cc1707
17f22c1c20
e448bb69f0
9aa4e3ce2b
ca8a508913
1 .github/workflows/build_documentation.yml (vendored)

@@ -13,7 +13,6 @@ jobs:
with:
  commit_sha: ${{ github.sha }}
  package: trl
  repo_owner: lvwerra
  version_tag_suffix: ""
secrets:
  token: ${{ secrets.HUGGINGFACE_PUSH }}
1 .github/workflows/build_pr_documentation.yml (vendored)

@@ -14,5 +14,4 @@ jobs:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}
package: trl
repo_owner: lvwerra
version_tag_suffix: ""
2 .github/workflows/stale.yml (vendored)

@@ -7,7 +7,7 @@ on:
jobs:
  close_stale_issues:
    name: Close Stale Issues
    if: github.repository == 'lvwerra/trl'
    if: github.repository == 'huggingface/trl'
    runs-on: ubuntu-latest
    env:
      GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
CITATION.cff

@@ -17,7 +17,7 @@ authors:
    family-names: Thrush
  - given-names: Nathan
    family-names: Lambert
repository-code: 'https://github.com/lvwerra/trl'
repository-code: 'https://github.com/huggingface/trl'
abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
keywords:
  - rlhf
30 README.md

@@ -6,14 +6,14 @@
> Full stack transformer language models with reinforcement learning.

<p align="center">
<a href="https://github.com/lvwerra/trl/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/github/license/lvwerra/trl.svg?color=blue">
<a href="https://github.com/huggingface/trl/blob/main/LICENSE">
<img alt="License" src="https://img.shields.io/github/license/huggingface/trl.svg?color=blue">
</a>
<a href="https://huggingface.co/docs/trl/index">
<img alt="Documentation" src="https://img.shields.io/website/http/huggingface.co/docs/trl/index.svg?down_color=red&down_message=offline&up_message=online">
</a>
<a href="https://github.com/lvwerra/trl/releases">
<img alt="GitHub release" src="https://img.shields.io/github/release/lvwerra/trl.svg">
<a href="https://github.com/huggingface/trl/releases">
<img alt="GitHub release" src="https://img.shields.io/github/release/huggingface/trl.svg">
</a>
</p>

@@ -24,7 +24,7 @@
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
</div>

`trl` is a full stack library where we provide a set of tools to train transformer language models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most decoder and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.
`trl` is a full stack library where we provide a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most decoder and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.

**Highlights:**

@@ -32,7 +32,7 @@
- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
- [Examples](https://github.com/lvwerra/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
- [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.

## How PPO works
Fine-tuning a language model via PPO consists of roughly three steps:

@@ -60,7 +60,7 @@ pip install trl
### From source
If you want to run the examples in the repository, a few additional libraries are required. Clone the repository and install it with pip:
```bash
git clone https://github.com/lvwerra/trl.git
git clone https://github.com/huggingface/trl.git
cd trl/
pip install .
```

@@ -106,7 +106,7 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2")
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

@@ -162,16 +162,6 @@ reward = [torch.tensor(1.0)]
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
```

### Advanced example: IMDB sentiment
For a detailed example check out the example python script `examples/scripts/sentiment_tuning.py`, where GPT2 is fine-tuned to generate positive movie reviews. A few examples from the language models before and after optimisation are given below:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/table_imdb_preview.png" width="800">
<p style="text-align: center;"> <b>Figure:</b> A few review continuations before and after optimisation. </p>
</div>

Have a look at more examples inside the [`examples/`](https://github.com/lvwerra/trl/tree/main/examples) folder.

## References

### Proximal Policy Optimisation

@@ -184,11 +174,11 @@ The language models utilize the `transformers` library by 🤗 Hugging Face.

```bibtex
@misc{vonwerra2022trl,
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert},
  author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang},
  title = {TRL: Transformer Reinforcement Learning},
  year = {2020},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/lvwerra/trl}}
  howpublished = {\url{https://github.com/huggingface/trl}}
}
```
docs/source/_toctree.yml

@@ -5,8 +5,12 @@
  title: Quickstart
- local: installation
  title: Installation
- local: how_to_train
  title: PPO Training FAQ
- local: use_model
  title: Use Trained Models
- local: customization
  title: Customize your Training
  title: Customize the Training
- local: logging
  title: Understanding Logs
title: Get started

@@ -23,6 +27,10 @@
  title: Best of N Sampling
- local: dpo_trainer
  title: DPO Trainer
- local: ddpo_trainer
  title: Denoising Diffusion Policy Optimization
- local: text_environments
  title: Text Environments
title: API
- sections:
  - local: sentiment_tuning

@@ -33,6 +41,8 @@
  title: Detoxifying a Language Model
- local: using_llama_models
  title: Training StackLlama
- local: learning_tools
  title: Learning to Use Tools
- local: multi_adapter_rl
  title: Multi Adapter RLHF
title: Examples
docs/source/customization.mdx

@@ -16,7 +16,7 @@ Then make sure you have selected multi-gpu / multi-node setup. You can then run
accelerate launch your_script.py
```

Refer to the [examples page](https://github.com/lvwerra/trl/tree/main/examples) for more details.
Refer to the [examples page](https://github.com/huggingface/trl/tree/main/examples) for more details.

## Use different optimizers

@@ -180,18 +180,21 @@ else:
    sentiment_pipe = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)
```

## Use torch distributed
The torch.distributed package provides PyTorch-native methods to distribute a network over several machines (mostly useful if there are several GPU nodes). It copies the model to each GPU, runs the forward and backward passes on each, and then applies the mean of the gradients of all GPUs to each one. If running torch 1.XX, you can call `torch.distributed.launch`, like

`python -m torch.distributed.launch --nproc_per_node=1 reward_summarization.py --bf16`

For torch 2.+ `torch.distributed.launch` is deprecated and one needs to run:
`torchrun --nproc_per_node=1 reward_summarization.py --bf16`
or
`python -m torch.distributed.run --nproc_per_node=1 reward_summarization.py --bf16`

Note that using `python -m torch.distributed.launch --nproc_per_node=1 reward_summarization.py --bf16` with torch 2.0 ends in
ValueError: Some specified arguments are not used by the HfArgumentParser: ['--local-rank=0']
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 194889) of binary: /home/ubuntu/miniconda3/envs/trl/bin/python

## Use score scaling/normalization/clipping
As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://arxiv.org/abs/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`:
```python
from trl import PPOConfig

ppo_config = {
    "use_score_scaling": True,
    "use_score_norm": True,
    "score_clip": 0.5,
}
config = PPOConfig(**ppo_config)
```

To run `sentiment_tuning.py`, you can use the following command:
```
python examples/scripts/sentiment_tuning.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
```
144 docs/source/ddpo_trainer.mdx (new file)

@@ -0,0 +1,144 @@
# Denoising Diffusion Policy Optimization
## The why

| Before | After DDPO finetuning |
| --- | --- |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_squirrel.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_squirrel.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_crab.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_crab.png"/></div> |
| <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/pre_starfish.png"/></div> | <div style="text-align: center"><img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/post_starfish.png"/></div> |

## Getting started with Stable Diffusion finetuning with reinforcement learning

The machinery for finetuning Stable Diffusion models with reinforcement learning makes heavy use of HuggingFace's `diffusers` library. A reason for stating this is that getting started requires a bit of familiarity with the `diffusers` library concepts, mainly two of them: pipelines and schedulers.
Right out of the box (`diffusers` library), there isn't a `Pipeline` nor a `Scheduler` instance that is suitable for finetuning with reinforcement learning. Some adjustments need to be made.

This library provides a pipeline interface that must be implemented in order to be used with the `DDPOTrainer`, which is the main machinery for fine-tuning Stable Diffusion with reinforcement learning. **Note: Only the StableDiffusion architecture is supported at this point.**
There is a default implementation of this interface that you can use out of the box. Assuming the default implementation is sufficient and/or to get things moving, refer to the training example alongside this guide.

The point of the interface is to fuse the pipeline and the scheduler into one object, which keeps the constraints all in one place. The interface was designed in the hope of catering to pipelines and schedulers beyond the examples in this repository and elsewhere at the time of writing. The scheduler step is also a method of this pipeline interface; this may seem redundant given that the raw scheduler is accessible via the interface, but it is the only way to constrain the scheduler step output to an output type befitting the algorithm at hand (DDPO).

For a more detailed look into the interface and the associated default implementation, go [here](https://github.com/lvwerra/trl/tree/main/trl/models/modeling_sd_base.py).

Note that the default implementation has a LoRA implementation path and a non-LoRA implementation path. The LoRA flag is enabled by default and can be turned off by passing in the flag to do so. LoRA-based training is faster, and the LoRA-associated hyperparameters responsible for model convergence aren't as finicky as those of non-LoRA training.

In addition, you are expected to provide a reward function and a prompt function. The reward function is used to evaluate the generated images, and the prompt function is used to generate the prompts that are used to generate the images.
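As a rough sketch of how these pieces fit together, the snippet below wires a toy reward function and a prompt function into the trainer with the default pipeline implementation. The `DDPOTrainer` argument order and the exact reward/prompt function signatures shown here are assumptions modeled on the example script, not a definitive API reference; the toy brightness reward is purely illustrative.

```python
# Hedged sketch: the argument order and callback signatures are assumptions, see the note above.
import torch
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline


def prompt_fn():
    # assumed to return (prompt, prompt_metadata)
    return "a photo of a squirrel", {}


def reward_fn(images, prompts, prompt_metadata):
    # assumed to receive a batch of images and return (rewards, reward_metadata);
    # a toy reward (mean pixel value) stands in for a real scorer here
    rewards = torch.stack([image.float().mean() for image in images])
    return rewards, {}


config = DDPOConfig()  # most fields have defaults; override batch sizes etc. as needed
pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")

trainer = DDPOTrainer(config, reward_fn, prompt_fn, pipeline)
trainer.train()
```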
## Getting started with `examples/scripts/stable_diffusion_tuning.py`

The `stable_diffusion_tuning.py` script is a working example of using the `DDPO` trainer to finetune a Stable Diffusion model. This example explicitly configures a small subset of the overall parameters associated with the config object (`DDPOConfig`).

**Note:** one A100 GPU is recommended to get this running. Anything below an A100 will not be able to run this example script, and even if it does via relatively smaller-sized parameters, the results will most likely be poor.

Almost every configuration parameter has a default. There is only one commandline flag argument that is required of the user to get things up and running. The user is expected to have a [huggingface user access token](https://huggingface.co/docs/hub/security-tokens) that will be used to upload the model post finetuning to HuggingFace hub. The following bash command gets things running:

```bash
python stable_diffusion_tuning.py --hf_user_access_token <token>
```

Again, the script uses a small subset of parameters to configure the trainer, and all of these are configurable via the commandline.
It should be noted (in general) that because the trainer uses `accelerate` as a core component, some of the parameters are `accelerate`'s.
The commandline flags associated with the example script's parameters are listed below.

| parameter | description | default |
| ---- | ---- | ---- |
| `--hf_hub_aesthetic_model_id` | The HuggingFace model hub id of the aesthetic scorer model | `"trl-lib/ddpo-aesthetic-predictor"` |
| `--hf_hub_aesthetic_model_filename` | The filename of the aesthetic scorer model | `"aesthetic-model.pth"` |
| `--pretrained_model` | The string id of the pretrained Stable Diffusion model | `"runwayml/stable-diffusion-v1-5"` |
| `--pretrained_revision` | The revision of the pretrained Stable Diffusion model | `"main"` |
| `--num_epochs` | The number of epochs to train for | `200` |
| `--train_batch_size` | The batch size to use for training | `3` |
| `--sample_batch_size` | The batch size to use for sampling | `6` |
| `--gradient_accumulation_steps` | The number of accelerator-based gradient accumulation steps to use | `1` |
| `--sample_num_steps` | The number of steps to sample for | `50` |
| `--sample_num_batches_per_epoch` | The number of batches to sample per epoch | `4` |
| `--log_with` | The logger to use. Either `wandb` or `tensorboard` | `wandb` |
| `--per_prompt_stat_tracking` | Whether to track stats per prompt. If false, advantages will be calculated using the mean and std of the entire batch as opposed to tracking per prompt | `True` |
| `--per_prompt_stat_tracking_buffer_size` | The size of the buffer to use for tracking stats per prompt | `32` |
| `--tracker_project_name` | The name of the project for use on the tracking platform (wandb/tensorboard/etc.) | `"stable_diffusion_training"` |
| `--logging_dir` | The directory to use for logging | `"logs"` |
| `--project_dir` | The directory to use for saving the model | `"save"` |
| `--automatic_checkpoint_naming` | Whether to automatically name model checkpoints | `True` |
| `--total_limit` | Number of checkpoints to keep before overwriting old ones | `5` |
| `--hf_hub_model_id` | The HuggingFace model hub id to use for saving the model | `"ddpo-finetuned-sd-model"` |
| `--hf_user_access_token` | The HuggingFace user access token | `None` |

The following are things to keep in mind in general while configuring the trainer (beyond the use case of the example script); the code checks these for you as well:

- The configurable sample batch size should be greater than or equal to the configurable training batch size
- The configurable sample batch size must be divisible by the configurable train batch size
- The configurable sample batch size must be divisible by both the configurable gradient accumulation steps and the configurable accelerator processes count
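As a quick, self-contained illustration of the batch-size constraints above (the variable names are hypothetical stand-ins for the corresponding config values; the trainer performs equivalent checks for you, per the note above):

```python
# Hypothetical values for illustration only.
sample_batch_size = 6
train_batch_size = 3
gradient_accumulation_steps = 1
num_accelerator_processes = 1  # e.g. the number of GPUs accelerate launches with

# sample batch size >= train batch size, and divisible by it
assert sample_batch_size >= train_batch_size
assert sample_batch_size % train_batch_size == 0
# sample batch size divisible by the gradient accumulation steps and the process count
assert sample_batch_size % gradient_accumulation_steps == 0
assert sample_batch_size % num_accelerator_processes == 0
```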
## Setting up the image logging hook function

Expect the function to be given a list of lists of the form
```python
[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]
```
where `image`, `prompt`, `prompt_metadata`, `rewards`, and `reward_metadata` are batched.
The last list in the list of lists represents the last sample batch. You are likely to want to log this one.
While you are free to log however you want, the use of `wandb` or `tensorboard` is recommended.

### Key terms

- `rewards` : The reward/score is a numerical value associated with the generated image and is key to steering the RL process
- `reward_metadata` : The reward metadata is the metadata associated with the reward. Think of this as extra information payload delivered alongside the reward
- `prompt` : The prompt is the text that is used to generate the image
- `prompt_metadata` : The prompt metadata is the metadata associated with the prompt. A situation where this will not be empty is when the reward model consists of a [`FLAVA`](https://huggingface.co/docs/transformers/model_doc/flava) setup, where questions and ground-truth answers (linked to the generated image) are expected with the generated image (see here: https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)
- `image` : The image generated by the Stable Diffusion model

Example code for logging sampled images with `wandb` is given below.

```python
# for logging these images to wandb
import numpy as np
from PIL import Image


def image_outputs_hook(image_data, global_step, accelerate_logger):
    # For the sake of this example, we only care about the last batch,
    # hence we extract the last element of the list
    result = {}
    images, prompts, _, rewards, _ = image_data[-1]
    for i, image in enumerate(images):
        pil = Image.fromarray(
            (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )
        pil = pil.resize((256, 256))
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
    accelerate_logger.log_images(
        result,
        step=global_step,
    )
```

### Using the finetuned model

Assuming you're done with all the epochs and have pushed your model up to the hub, you can use the finetuned model as follows:

```python
import torch
from trl import DefaultDDPOStableDiffusionPipeline

pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)

prompts = ["squirrel", "crab", "starfish", "whale", "sponge", "plankton"]
results = pipeline(prompts)

for prompt, image in zip(prompts, results.images):
    image.save(f"{prompt}.png")
```

## Credits

This work is heavily influenced by the repo [here](https://github.com/kvablack/ddpo-pytorch) and the associated paper [Training Diffusion Models with Reinforcement Learning by Kevin Black, Michael Janner, Yilun Du, Ilya Kostrikov, Sergey Levine](https://arxiv.org/abs/2305.13301).
docs/source/detoxifying_a_lm.mdx

@@ -4,12 +4,12 @@ Language models (LMs) are known to sometimes generate toxic outputs. In this exa
Read this section to follow our investigation on how we can reduce toxicity in a wide range of LMs, from 125m parameters to 6B parameters!

Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/lvwerra/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:
Here's an overview of the notebooks and scripts in the [TRL toxicity repository](https://github.com/huggingface/trl/tree/main/examples/toxicity/scripts) as well as the link for the interactive demo:

| File | Description | Colab link |
|---|---| --- |
| [`gpt-j-6b-toxicity.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
| [`evaluate-toxicity.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
| [`gpt-j-6b-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py) | Detoxify `GPT-J-6B` using PPO | x |
| [`evaluate-toxicity.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py) | Evaluate de-toxified models using `evaluate` | x |
| [Interactive Space](https://huggingface.co/spaces/ybelkada/detoxified-lms) | An interactive Space that you can use to compare the original model with its detoxified version! | x |

## Context

@@ -174,7 +174,7 @@ Below are few generation examples of `gpt-j-6b-detox` model:
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-toxicity-examples.png">
</div>

The evaluation script can be found [here](https://github.com/lvwerra/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).
The evaluation script can be found [here](https://github.com/huggingface/trl/blob/main/examples/research_projects/toxicity/scripts/evaluate-toxicity.py).

### Discussions
docs/source/dpo_trainer.mdx

@@ -1,6 +1,6 @@
# DPO Trainer

TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/dpo.py`](https://github.com/lvwerra/trl/blob/main/examples/dpo.py).
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/dpo.py).

The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.

@@ -77,6 +77,15 @@ dpo_trainer.train()

Note that `beta` is the temperature parameter for the DPO loss, typically something in the range of `0.1` to `0.5`. We ignore the reference model as `beta` -> 0.

## Logging

While training and evaluating we record the following reward metrics:

* `rewards/chosen`: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses, scaled by beta
* `rewards/rejected`: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses, scaled by beta
* `rewards/accuracies`: mean of how often the chosen rewards are greater than the corresponding rejected rewards
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
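As a rough, self-contained sketch of how the reward metrics listed above relate to the policy and reference log probabilities (made-up numbers, assuming sequence-level log-probs are already computed; this is not the trainer's internal code):

```python
import torch

beta = 0.1  # DPO temperature

# hypothetical sequence-level log probabilities for two preference pairs
policy_chosen_logps = torch.tensor([-12.0, -9.5])
policy_rejected_logps = torch.tensor([-14.0, -9.0])
reference_chosen_logps = torch.tensor([-13.0, -10.0])
reference_rejected_logps = torch.tensor([-13.5, -9.2])

chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps)
rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps)

metrics = {
    "rewards/chosen": chosen_rewards.mean().item(),
    "rewards/rejected": rejected_rewards.mean().item(),
    "rewards/accuracies": (chosen_rewards > rejected_rewards).float().mean().item(),
    "rewards/margins": (chosen_rewards - rejected_rewards).mean().item(),
}
print(metrics)
```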
## DPOTrainer

[[autodoc]] DPOTrainer
66 docs/source/how_to_train.md (new file)

@@ -0,0 +1,66 @@
# Training FAQ

## What Metrics Should I Look at?

When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.

To address this, we recommend focusing on two key metrics first:

**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.

However, there are more metrics that can be useful for debugging; check out the [logging section](logging).

## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?

When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.

However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/kl-example.png">
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://arxiv.org/pdf/1909.08593.pdf">https://arxiv.org/pdf/1909.08593.pdf</a>. </p>
</div>

To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.

## What Is the Concern with Negative KL Divergence?

If you generate text by purely sampling from the model distribution, things work fine in general. But when you use the `generate` method there are a few caveats, because it does not always purely sample depending on the settings, which can cause the KL-divergence to go negative. Essentially, when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in several cases:

- **top-k sampling**: the model can smooth out the probability distribution, causing the top-k tokens to have a smaller probability than under the reference model while still being selected
- **min_length**: this ignores the EOS token until `min_length` is reached. Thus the model can assign a very high log prob to the EOS token and very low prob to all others until min_length is reached
- **batched generation**: finished sequences in a batch are padded until all generations are finished. The model can learn to assign very low probabilities to the padding tokens unless they are properly masked or removed.

These are just a few examples. Why is negative KL an issue? The total reward `R` is computed as `R = r - beta * KL`, so if the model can learn how to drive the KL-divergence negative, it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than to actually learn the reward function. In addition, the KL can become arbitrarily small, so the actual reward can be very small compared to it.
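As a rough, self-contained illustration of the `R = r - beta * KL` shaping described above (a per-token sketch with made-up numbers, using the simple `log_p_active - log_p_ref` KL estimate; this is not the trainer's internal implementation):

```python
import torch

beta = 0.1
score = torch.tensor(1.0)  # scalar score from the reward model for one response

# per-token log probabilities of the sampled tokens under the active and reference models
logprobs_active = torch.tensor([-2.1, -0.7, -1.3])
logprobs_ref = torch.tensor([-2.0, -0.9, -1.2])

per_token_kl = logprobs_active - logprobs_ref  # negative entries mean the active model puts less mass on the sampled token
kl_penalty = -beta * per_token_kl
total_reward = score + kl_penalty.sum()        # R = r - beta * KL

# If generation settings let the active model systematically put less mass on the sampled
# tokens than the reference model, the summed KL goes negative and the penalty becomes a bonus.
print(per_token_kl.sum().item(), total_reward.item())
```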
So how should you generate text for PPO training? Let's have a look!

## How to generate text for training?

In order to avoid the KL issues described above we recommend using the following settings:

```python
generation_kwargs = {
    "min_length": -1, # don't ignore the EOS token (see above)
    "top_k": 0.0, # no top-k sampling
    "top_p": 1.0, # no nucleus sampling
    "do_sample": True, # yes, we want to sample
    "pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
    "max_new_tokens": 32, # specify how many tokens you want to generate at most
}
```

With these settings we usually don't encounter any issues. You can also experiment with other settings, but if you encounter issues with negative KL-divergence, try to go back to these and see if they persist.

## How can I debug my own use-case?

Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:

- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale, so try to use small model variants during the development phase and scale up once that works. That being said, you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that, and then add more complexity after that.
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut off generations too soon. These things are very hard to see in the metrics but very obvious if you look at the generations.
- **Inspect the reward model**: If your reward is not improving over time, maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query, which the model can't affect, so you might need to normalize this (e.g. reward of query+response minus reward of the query).

These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!
docs/source/index.mdx

@@ -13,19 +13,45 @@ The library is integrated with 🤗 [transformers](https://github.com/huggingfac
Check the appropriate sections of the documentation depending on your needs:

API documentation:
## API documentation

- [Model Classes](models): *A brief overview of what each public model class does.*
- [`SFTTrainer`](sft_trainer): *Supervised fine-tune your model easily with `SFTTrainer`*
- [`RewardTrainer`](reward_trainer): *Easily train your reward model using `RewardTrainer`.*
- [`PPOTrainer`](trainer): *Further fine-tune the supervised fine-tuned model using the PPO algorithm*
- [Best-of-N Samppling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
- [Best-of-N Sampling](best-of-n): *Use best of n sampling as an alternative way to sample predictions from your active model*
- [`DPOTrainer`](trainer): *Direct Preference Optimization training using `DPOTrainer`.*
- [`TextEnvironment`](text_environment): *Text environment to train your model using tools with RL.*

Examples:
## Examples

- [Sentiment Tuning](sentiment_tuning): *Fine-tune your model to generate positive movie content*
- [Training with PEFT](lora_tuning_peft): *Memory efficient RLHF training using adapters with PEFT*
- [Detoxifying LLMs](detoxifying_a_lm): *Detoxify your language model through RLHF*
- [StackLlama](using_llama_models): *End-to-end RLHF training of a Llama model on Stack exchange dataset*
- [Learning with Tools](learning_tools): *Walkthrough of using `TextEnvironments`*
- [Multi-Adapter Training](multi_adapter_rl): *Use a single base model and multiple adapters for memory efficient end-to-end training*

## Blog posts

<div class="mt-10">
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/rlhf">
<img src="https://raw.githubusercontent.com/huggingface/blog/main/assets/120_rlhf/thumbnail.png" alt="thumbnail">
<p class="text-gray-700">Illustrating Reinforcement Learning from Human Feedback</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/trl-peft">
<img src="https://github.com/huggingface/blog/blob/main/assets/133_trl_peft/thumbnail.png?raw=true" alt="thumbnail">
<p class="text-gray-700">Fine-tuning 20B LLMs with RLHF on a 24GB consumer GPU</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/stackllama">
<img src="https://github.com/huggingface/blog/blob/main/assets/138_stackllama/thumbnail.png?raw=true" alt="thumbnail">
<p class="text-gray-700">StackLLaMA: A hands-on guide to train LLaMA with RLHF</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="https://huggingface.co/blog/dpo-trl">
<img src="https://github.com/huggingface/blog/blob/main/assets/157_dpo_trl/dpo_thumbnail.png?raw=true" alt="thumbnail">
<p class="text-gray-700">Fine-tune Llama 2 with DPO</p>
</a>
</div>
</div>
docs/source/installation.mdx

@@ -12,7 +12,7 @@ pip install trl
You can also install the latest version from source. First clone the repo and then run the installation with `pip`:

```bash
git clone https://github.com/lvwerra/trl.git
git clone https://github.com/huggingface/trl.git
cd trl/
pip install -e .
```
229 docs/source/learning_tools.mdx (new file)

@@ -0,0 +1,229 @@
# Learning Tools (Experimental 🧪)

Using Large Language Models (LLMs) with tools has been a popular topic recently, with awesome works such as [ToolFormer](https://arxiv.org/abs/2302.04761) and [ToolBench](https://arxiv.org/pdf/2305.16504.pdf). In TRL, we provide a simple example of how to teach an LLM to use tools with reinforcement learning.

Here's an overview of the scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples/research_projects/tools):

| File | Description |
|---|---|
| [`calculator.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/calculator.py) | Script to train an LLM to use a calculator with reinforcement learning. |
| [`triviaqa.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/triviaqa.py) | Script to train an LLM to use a wiki tool to answer questions. |
| [`python_interpreter.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/tools/python_interpreter.py) | Script to train an LLM to use a python interpreter to solve math puzzles. |

<Tip warning={true}>

Note that the scripts above rely heavily on the `TextEnvironment` API which is still under active development. The API may change in the future. Please see [`TextEnvironment`](text_environment) for the related docs.
</Tip>

## Learning to Use a Calculator

The rough idea is as follows:

1. Load a tool such as [ybelkada/simple-calculator](https://huggingface.co/spaces/ybelkada/simple-calculator) that parses a text calculation like `"14 + 34"` and returns the calculated number:
```python
from transformers import AutoTokenizer, load_tool
tool = load_tool("ybelkada/simple-calculator")
tool_fn = lambda text: str(round(float(tool(text)), 2))  # rounding to 2 decimal places
```
2. Define a reward function that returns a positive reward if the tool returns the correct answer. In the script we create a dummy reward function like `reward_fn = lambda x: 1`, but we override the rewards directly later.
3. Create a prompt on how to use the tools
```python
# system prompt
prompt = """\
What is 13.1-3?

<request><SimpleCalculatorTool>13.1-3<call>10.1<response>

Result=10.1<submit>

What is 4*3?

<request><SimpleCalculatorTool>4*3<call>12<response>

Result=12<submit>

What is 12.1+1?

<request><SimpleCalculatorTool>12.1+1<call>13.1<response>

Result=13.1<submit>

What is 12.1-20?

<request><SimpleCalculatorTool>12.1-20<call>-7.9<response>

Result=-7.9<submit>"""
```
4. Create a `trl.TextEnvironment` with the model
```python
env = TextEnvironment(
    model,
    tokenizer,
    {"SimpleCalculatorTool": tool_fn},
    reward_fn,
    prompt,
    generation_kwargs=generation_kwargs,
)
```
5. Then generate some data such as `tasks = ["\n\nWhat is 13.1-3?", "\n\nWhat is 4*3?"]` and run the environment with `queries, responses, masks, rewards, histories = env.run(tasks)`. The environment will look for the `<call>` token in the prompt and append the tool output to the response; it will also return the mask associated with the response. You can further use the `histories` to visualize the interaction between the model and the tool; `histories[0].show_text()` will show the text with color-coded tool output and `histories[0].show_tokens(tokenizer)` will visualize the tokens.

6. Finally, we can train the model with `train_stats = ppo_trainer.step(queries, responses, rewards, masks)`. The trainer will use the mask to ignore the tool output when computing the loss; make sure to pass that argument to `step`.

## Experiment results

We trained a model with the above script for 10 random seeds. You can reproduce the run with the following command. Feel free to remove the `--slurm-*` arguments if you don't have access to a slurm cluster.

```
WANDB_TAGS="calculator_final" python benchmark/benchmark.py \
    --command "python examples/calculator_few_shots_env.py" \
    --num-seeds 10 \
    --start-seed 1 \
    --workers 10 \
    --slurm-gpus-per-task 1 \
    --slurm-ntasks 1 \
    --slurm-total-cpus 8 \
    --slurm-template-path benchmark/trl.slurm_template
```

We can then use [`openrlbenchmark`](https://github.com/openrlbenchmark/openrlbenchmark) which generates the following plot.
```
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=openrlbenchmark&wpn=trl&xaxis=_step&ceik=trl_ppo_trainer_config.value.tracker_project_name&cen=trl_ppo_trainer_config.value.log_with&metrics=env/reward_mean&metrics=objective/kl' \
    'wandb?tag=calculator_final&cl=calculator_mask' \
    --env-ids trl \
    --check-empty-runs \
    --pc.ncols 2 \
    --pc.ncols-legend 1 \
    --output-filename static/0compare \
    --scan-history
```

As we can see, while 1-2 experiments crashed for some reason, most of the runs obtained near perfect proficiency in the calculator task.

## (Early Experiments 🧪): learning to use a wiki tool for question answering

The [ToolFormer](https://arxiv.org/abs/2302.04761) paper shows an interesting use case that utilizes a Wikipedia search tool to help answer questions. In this section, we attempt to perform similar experiments but use RL instead to teach the model to use a wiki tool on the [TriviaQA](https://nlp.cs.washington.edu/triviaqa/) dataset.

<Tip warning={true}>

**Note that many settings are different so the results are not directly comparable.**
</Tip>

### Building a search index

Since [ToolFormer](https://arxiv.org/abs/2302.04761) did not open-source its code, we needed to first replicate the search index. It is mentioned in their paper that the authors built the search index using a BM25 retriever that indexes the Wikipedia dump from [KILT](https://github.com/facebookresearch/KILT).

Fortunately, [`pyserini`](https://github.com/castorini/pyserini) already implements the BM25 retriever and provides a prebuilt index for the KILT Wikipedia dump. We can use the following code to search the index.

```python
from pyserini.search.lucene import LuceneSearcher
import json
searcher = LuceneSearcher.from_prebuilt_index('wikipedia-kilt-doc')
def search(query):
    hits = searcher.search(query, k=1)
    hit = hits[0]
    contents = json.loads(hit.raw)['contents']
    return contents
print(search("tennis racket"))
```
```
Racket (sports equipment)
A racket or racquet is a sports implement consisting of a handled frame with an open hoop across which a network of strings or catgut is stretched tightly. It is used for striking a ball or shuttlecock in games such as squash, tennis, racquetball, and badminton. Collectively, these games are known as racket sports. Racket design and manufacturing has changed considerably over the centuries.

The frame of rackets for all sports was traditionally made of solid wood (later laminated wood) and the strings of animal intestine known as catgut. The traditional racket size was limited by the strength and weight of the wooden frame which had to be strong enough to hold the strings and stiff enough to hit the ball or shuttle. Manufacturers started adding non-wood laminates to wood rackets to improve stiffness. Non-wood rackets were made first of steel, then of aluminum, and then carbon fiber composites. Wood is still used for real tennis, rackets, and xare. Most rackets are now made of composite materials including carbon fiber or fiberglass, metals such as titanium alloys, or ceramics.
...
```

We then basically deployed this snippet as a Hugging Face space [here](https://huggingface.co/spaces/vwxyzjn/pyserini-wikipedia-kilt-doc), so that we can use the space as a `transformers.Tool` later.

### Experiment settings

We use the following settings:

* use the `bigcode/starcoderbase` model as the base model
* use the `pyserini-wikipedia-kilt-doc` space as the wiki tool and only use the first paragraphs of the search result, allowing the `TextEnvironment` to obtain at most `max_tool_reponse=400` response tokens from the tool.
* test if the response contains the answer string; if so, give a reward of 1, otherwise, give a reward of 0 (a minimal sketch of this check appears after this list).
* notice this is a simplified evaluation criterion. In [ToolFormer](https://arxiv.org/abs/2302.04761), the authors check if the first 20 words of the response contain the correct answer.
* use the following prompt that demonstrates the usage of the wiki tool.
```python
prompt = """\
Answer the following question:

Q: In which branch of the arts is Patricia Neary famous?
A: Ballets
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
Result=Ballets<submit>

Q: Who won Super Bowl XX?
A: Chicago Bears
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
Result=Chicago Bears<submit>

Q: """
```
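As referenced in the settings list above, a minimal sketch of the exact-match reward check described there (the function name and example strings are illustrative, not the script's actual code):

```python
import torch


def exact_match_reward(responses, answers):
    """Give 1.0 if the gold answer string appears in the generated response, else 0.0."""
    rewards = []
    for response, answer in zip(responses, answers):
        rewards.append(torch.tensor(1.0 if answer.lower() in response.lower() else 0.0))
    return rewards


print(exact_match_reward(["A2: <request><Wiki>Super Bowl XX<call>...<response>Result=Chicago Bears<submit>"],
                         ["Chicago Bears"]))
# [tensor(1.)]
```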
### Result and Discussion

Our experiments show that the agent can learn to use the wiki tool to answer questions. The learning curves mostly go up, but one of the experiments did crash.

Wandb report is [here](https://wandb.ai/costa-huang/cleanRL/reports/TriviaQA-Final-Experiments--Vmlldzo1MjY0ODk5) for further inspection.

Note that the correct rate of the trained model is on the low end, which could be due to the following reasons:

* **incorrect searches:** When given the question "What is Bruce Willis' real first name?", if the model searches for `Bruce Willis`, our wiki tool returns "Patrick Poivey (born 18 February 1948) is a French actor. He is especially known for his voice: he is the French dub voice of Bruce Willis since 1988." But a correct search should return "Walter Bruce Willis (born March 19, 1955) is an American former actor. He achieved fame with a leading role on the comedy-drama series Moonlighting (1985–1989) and appeared in over a hundred films, gaining recognition as an action hero after his portrayal of John McClane in the Die Hard franchise (1988–2013) and other roles."

* **unnecessarily long response**: The wiki tool by default sometimes outputs very long sequences. E.g., when the wiki tool searches for "Brown Act":
    * Our wiki tool returns "The Ralph M. Brown Act, located at California Government Code 54950 "et seq.", is an act of the California State Legislature, authored by Assemblymember Ralph M. Brown and passed in 1953, that guarantees the public's right to attend and participate in meetings of local legislative bodies."
    * [ToolFormer](https://arxiv.org/abs/2302.04761)'s wiki tool returns "The Ralph M. Brown Act is an act of the California State Legislature that guarantees the public's right to attend and participate in meetings of local legislative bodies." which is more succinct.

## (Early Experiments 🧪): solving math puzzles with python interpreter

In this section, we attempt to teach the model to use a python interpreter to solve math puzzles. The rough idea is to give the agent a prompt like the following:

```python
prompt = """\
Example of using a Python API to solve math questions.

Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?

<request><PythonInterpreter>
def solution():
    money_initial = 23
    bagels = 5
    bagel_cost = 3
    money_spent = bagels * bagel_cost
    money_left = money_initial - money_spent
    result = money_left
    return result
print(solution())
<call>72<response>

Result = 72 <submit>

Q: """
```

Training results TBD.
@ -14,16 +14,62 @@ If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir
|
||||
|
||||
## PPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data:
|
||||
|
||||
Key metrics to monitor. We want to maximize the reward, maintain a low KL divergence, and maximize entropy:
|
||||
1. `env/reward_mean`: The average reward obtained from the environment. Alias `ppo/mean_scores`, which is sed to specifically monitor the reward model.
|
||||
1. `env/reward_std`: The standard deviation of the reward obtained from the environment. Alias ``ppo/std_scores`, which is sed to specifically monitor the reward model.
|
||||
1. `env/reward_dist`: The histogram distribution of the reward obtained from the environment.
|
||||
1. `objective/kl`: The mean Kullback-Leibler (KL) divergence between the old and new policies. It measures how much the new policy deviates from the old policy. The KL divergence is used to compute the KL penalty in the objective function.
|
||||
1. `objective/kl_dist`: The histogram distribution of the `objective/kl`.
|
||||
1. `objective/kl_coef`: The coefficient for Kullback-Leibler (KL) divergence in the objective function.
|
||||
1. `ppo/mean_non_score_reward`: The **KL penalty** calculated by `objective/kl * objective/kl_coef` as the total reward for optimization to prevent the new policy from deviating too far from the old policy.
|
||||
1. `objective/entropy`: The entropy of the model's policy, calculated by `-logprobs.sum(-1).mean()`. High entropy means the model's actions are more random, which can be beneficial for exploration.
|
||||
|
||||
Training stats:
|
||||
1. `ppo/learning_rate`: The learning rate for the PPO algorithm.
|
||||
1. `ppo/policy/entropy`: The entropy of the model's policy, calculated by `pd = torch.nn.functional.softmax(logits, dim=-1); entropy = torch.logsumexp(logits, dim=-1) - torch.sum(pd * logits, dim=-1)`. It measures the randomness of the policy.
|
||||
1. `ppo/policy/clipfrac`: The fraction of probability ratios (old policy / new policy) that fell outside the clipping range in the PPO objective. This can be used to monitor the optimization process.
|
||||
1. `ppo/policy/approxkl`: The approximate KL divergence between the old and new policies, measured by `0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)`, corresponding to the `k2` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/policykl`: Similar to `ppo/policy/approxkl`, but measured by `masked_mean(old_logprobs - logprobs, mask)`, corresponding to the `k1` estimator in http://joschu.net/blog/kl-approx.html
|
||||
1. `ppo/policy/ratio`: The histogram distribution of the ratio between the new and old policies, used to compute the PPO objective.
|
||||
1. `ppo/policy/advantages_mean`: The average of the GAE (Generalized Advantage Estimation) advantage estimates. The advantage function measures how much better an action is compared to the average action at a state.
|
||||
1. `ppo/policy/advantages`: The histogram distribution of `ppo/policy/advantages_mean`.
|
||||
1. `ppo/returns/mean`: The mean of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance. See https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/ for more details.
|
||||
1. `ppo/returns/var`: The variance of the TD(λ) returns, calculated by `returns = advantage + values`, another indicator of model performance.
|
||||
1. `ppo/val/mean`: The mean of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var` : The variance of the values, used to monitor the value function's performance.
|
||||
1. `ppo/val/var_explained`: The explained variance for the value function, used to monitor the value function's performance.
|
||||
1. `ppo/val/clipfrac`: The fraction of the value function's predicted values that are clipped.
|
||||
1. `ppo/val/vpred`: The predicted values from the value function.
|
||||
1. `ppo/val/error`: The mean squared error between the `ppo/val/vpred` and returns, used to monitor the value function's performance.
|
||||
1. `ppo/loss/policy`: The policy loss for the Proximal Policy Optimization (PPO) algorithm.
|
||||
1. `ppo/loss/value`: The loss for the value function in the PPO algorithm. This value quantifies how well the function estimates the expected future rewards.
|
||||
1. `ppo/loss/total`: The total loss for the PPO algorithm. It is the sum of the policy loss and the value function loss.
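Below is a minimal sketch (using made-up tensors, not TRL's internal code) of how the `k1` and `k2` estimators referenced in the list above can be computed from per-token log-probabilities and a response mask:

```python
import torch

def masked_mean(values, mask):
    # Average only over positions where mask == 1 (the response tokens)
    return (values * mask).sum() / mask.sum()

# Hypothetical per-token log-probs for a batch of 2 responses of length 8
logprobs = torch.randn(2, 8)                        # new policy
old_logprobs = logprobs + 0.05 * torch.randn(2, 8)  # old policy
mask = torch.ones(2, 8)                             # 1 for response tokens, 0 for query/padding

# k1 estimator -> logged as `ppo/policy/policykl`
policykl = masked_mean(old_logprobs - logprobs, mask)

# k2 estimator -> logged as `ppo/policy/approxkl`
approxkl = 0.5 * masked_mean((logprobs - old_logprobs) ** 2, mask)

print(policykl.item(), approxkl.item())
```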
|
||||
|
||||
|
||||
Stats on queries, responses, and logprobs:
|
||||
1. `tokens/queries_len_mean`: The average length of the query tokens.
|
||||
1. `tokens/queries_len_std`: The standard deviation of the length of the query tokens.
|
||||
1. `tokens/queries_dist`: The histogram distribution of the length of the query tokens.
|
||||
1. `tokens/responses_len_mean`: The average length of the response tokens.
|
||||
1. `tokens/responses_len_std`: The standard deviation of the length of the response tokens.
|
||||
1. `tokens/responses_dist`: The histogram distribution of the length of the response tokens. (Costa: inconsistent naming, should be `tokens/responses_len_dist`)
|
||||
1. `objective/logprobs`: The histogram distribution of the log probabilities of the actions taken by the model.
|
||||
1. `objective/ref_logprobs`: The histogram distribution of the log probabilities of the actions taken by the reference model.
|
||||
|
||||
|
||||
|
||||
### Crucial values
|
||||
During training, many values are logged; here are the most important ones:
|
||||
|
||||
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment".
|
||||
2. `ppo/mean_scores`: The mean scores directly out of the reward model.
|
||||
3. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
|
||||
1. `env/reward_mean`,`env/reward_std`, `env/reward_dist`: the properties of the reward distribution from the "environment" / reward model
|
||||
1. `ppo/mean_non_score_reward`: The mean negated KL penalty during training (shows the delta between the reference model and the new policy over the batch in the step)
|
||||
|
||||
### Training stability parameters:
|
||||
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
|
||||
|
||||
1. `ppo/loss/value`: The value function loss -- will spike / NaN when not going well.
|
||||
2. `ppo/val/clipfrac`: The fraction of clipped values in the value function loss. This is often from 0.3 to 0.6.
|
||||
3. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
|
||||
1. `ppo/loss/value`: it will spike / NaN when not going well.
|
||||
1. `ppo/policy/ratio`: `ratio` being 1 is a baseline value, meaning that the probability of sampling a token is the same under the new and old policy. If the ratio is too high like 200, it means the probability of sampling a token is 200 times higher under the new policy than the old policy. This is a sign that the new policy is too different from the old policy, which will likely cause overoptimization and collapse training later on.
|
||||
1. `ppo/policy/clipfrac` and `ppo/policy/approxkl`: if `ratio` is too high, it will get clipped, resulting in a high `clipfrac` and a high `approxkl` as well (see the sketch after this list).
|
||||
1. `objective/kl`: it should stay positive so that the policy is not too far away from the reference policy.
|
||||
1. `objective/kl_coef`: The target coefficient with [`AdaptiveKLController`]. Often increases before numerical instabilities.
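As a rough illustration (hypothetical numbers, not TRL internals), the `ratio` and `clipfrac` values above are derived from the per-token probability ratio between the new and old policy:

```python
import torch

cliprange = 0.2  # assumed PPO clipping range

logprobs = torch.tensor([-1.0, -0.5, -2.0])      # new policy log-probs
old_logprobs = torch.tensor([-1.1, -0.4, -1.0])  # old policy log-probs

ratio = torch.exp(logprobs - old_logprobs)       # ppo/policy/ratio, ~1.0 is healthy
clipped = torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
clipfrac = (ratio != clipped).float().mean()     # ppo/policy/clipfrac

print(ratio)     # tensor([1.1052, 0.9048, 0.3679])
print(clipfrac)  # tensor(0.3333)
```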
|
@ -3,13 +3,13 @@
|
||||
The notebooks and scripts in these examples show how to use Low Rank Adaptation (LoRA) to fine-tune models in a memory-efficient manner. Most PEFT methods are supported in the peft library, but note that some PEFT methods, such as prompt tuning, are not supported.
|
||||
For more information on LoRA, see the [original paper](https://arxiv.org/abs/2106.09685).
|
||||
|
||||
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
|
||||
Here's an overview of the `peft`-enabled notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
|
||||
|
||||
| File | Task | Description | Colab link |
|
||||
|---|---|---|---|
|
||||
| [`stack_llama/rl_training.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
|
||||
| [`stack_llama/reward_modeling.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
|
||||
| [`stack_llama/supervised_finetuning.py`](https://github.com/lvwerra/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
|
||||
| [`stack_llama/rl_training.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/rl_training.py) | RLHF | Distributed fine-tuning of the 7b parameter LLaMA models with a learned reward model and `peft`. | |
|
||||
| [`stack_llama/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/reward_modeling.py) | Reward Modeling | Distributed training of the 7b parameter LLaMA reward model with `peft`. | |
|
||||
| [`stack_llama/supervised_finetuning.py`](https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py) | SFT | Distributed instruction/supervised fine-tuning of the 7b parameter LLaMA model with `peft`. | |
|
||||
|
||||
## Installation
|
||||
Note: peft is in active development, so we install directly from their Github page.
|
||||
|
@ -11,10 +11,10 @@ You just need to install `peft` and optionally install `bitsandbytes` as well if
|
||||
This approach involves three stages, which we summarize as follows:
|
||||
|
||||
1- Train a base model on the target domain (e.g. `imdb` dataset) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL.
|
||||
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/lvwerra/trl/tree/main/examples/scripts/reward_trainer.py)
|
||||
2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_trainer.py)
|
||||
3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL")
|
||||
|
||||
Make sure to use the same model (i.e. same architecure and same weights) for the stages 2 & 3.
|
||||
Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3.
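As an illustration of stage 1, here is a minimal sketch (model and dataset names are placeholders) of supervised fine-tuning with a LoRA adapter by passing a `peft_config` to the `SFTTrainer`:

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

trainer = SFTTrainer(
    "gpt2",                     # base model to adapt
    train_dataset=dataset,
    dataset_text_field="text",  # column containing the raw text
    peft_config=peft_config,    # the LoRA adapter configuration
)
trainer.train()
```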
|
||||
|
||||
## Quickstart
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
TRL supports custom reward modeling so that anyone can train a reward model on their own dataset and model.
|
||||
|
||||
Check out a complete flexible example inside [`examples/scripts`](https://github.com/lvwerra/trl/tree/main/examples/scripts/reward_trainer.py) folder.
|
||||
Check out a complete flexible example inside [`examples/scripts`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_trainer.py) folder.
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
@ -23,7 +23,7 @@ The `j` and `k` suffixes are used to denote the two sentences in the paired data
|
||||
|
||||
## Using the `RewardTrainer`
|
||||
|
||||
After standardizing your dataset, you can use the `RewardTrainer` as a classic HugingFace Trainer.
|
||||
After standardizing your dataset, you can use the `RewardTrainer` as a classic Hugging Face Trainer.
|
||||
You should pass an `AutoModelForSequenceClassification` model to the `RewardTrainer`.
|
||||
|
||||
### Leveraging the `peft` library to train a reward model
|
||||
|
@ -2,15 +2,15 @@
|
||||
|
||||
The notebooks and scripts in these examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`).
|
||||
|
||||
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/lvwerra/trl/tree/main/examples):
|
||||
Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples):
|
||||
|
||||
| File | Description | Colab link |
|
||||
|---|---| --- |
|
||||
| [`gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) | Fine-tune GPT2 to generate positive movie reviews. | [](https://colab.research.google.com/github/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb)
|
||||
| [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) | Fine-tune GPT2 to generate positive movie reviews. | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb)
|
||||
|
|
||||
| [`gpt2-sentiment-control.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment-control.ipynb) | Fine-tune GPT2 to generate movie reviews with controlled sentiment. | [](https://colab.research.google.com/github/lvwerra/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb)
|
||||
| [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/blob/main/examples/notebooks/gpt2-sentiment-control.ipynb) | Fine-tune GPT2 to generate movie reviews with controlled sentiment. | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb)
|
||||
|
|
||||
| [`gpt2-sentiment.py`](https://github.com/lvwerra/trl/blob/main/examples/ppo_trainer/sentiment_tuning.py) | Same as the notebook, but easier to use to use in multi-GPU setup with any architecture. | x |
|
||||
| [`gpt2-sentiment.py`](https://github.com/huggingface/trl/blob/main/examples/ppo_trainer/sentiment_tuning.py) | Same as the notebook, but easier to use in a multi-GPU setup with any architecture. | x |
|
||||
|
||||
|
||||
## Installation
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with a few lines of code on your dataset.
|
||||
|
||||
Check out a complete flexible example inside [`examples/scripts`](https://github.com/lvwerra/trl/tree/main/examples/scripts/sft_trainer.py) folder.
|
||||
Check out a complete flexible example inside [`examples/scripts`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_trainer.py) folder.
|
||||
|
||||
## Quickstart
|
||||
|
||||
@ -46,7 +46,7 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/lvwerra/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example.
|
||||
The above snippets will use the default training arguments from the [`transformers.TrainingArguments`](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments) class. If you want to modify that, make sure to create your own `TrainingArguments` object and pass it to the [`SFTTrainer`] constructor as it is done on the [`supervised_finetuning.py` script](https://github.com/huggingface/trl/blob/main/examples/stack_llama/scripts/supervised_finetuning.py) on the stack-llama example.
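For example, a minimal sketch (placeholder model and dataset) of overriding the defaults:

```python
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")

training_args = TrainingArguments(
    output_dir="./sft_output",
    per_device_train_batch_size=4,
    learning_rate=1e-5,
    num_train_epochs=1,
)

trainer = SFTTrainer(
    "facebook/opt-350m",
    args=training_args,         # your own TrainingArguments instead of the defaults
    train_dataset=dataset,
    dataset_text_field="text",
)
trainer.train()
```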
|
||||
|
||||
## Advanced usage
|
||||
|
||||
@ -111,6 +111,47 @@ trainer = SFTTrainer(
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
#### Using token_ids directly for `response_template`
|
||||
|
||||
Some tokenizers, like Llama 2 (`meta-llama/Llama-2-XXb-hf`), tokenize sequences differently depending on whether they have context or not. For example:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
|
||||
def print_tokens_with_ids(txt):
|
||||
tokens = tokenizer.tokenize(txt, add_special_tokens=False)
|
||||
token_ids = tokenizer.encode(txt, add_special_tokens=False)
|
||||
print(list(zip(tokens, token_ids)))
|
||||
|
||||
prompt = """### User: Hello\n\n### Assistant: Hi, how can I help you?"""
|
||||
print_tokens_with_ids(prompt) # [..., ('▁Hello', 15043), ('<0x0A>', 13), ('<0x0A>', 13), ('##', 2277), ('#', 29937), ('▁Ass', 4007), ('istant', 22137), (':', 29901), ...]
|
||||
|
||||
response_template = "### Assistant:"
|
||||
print_tokens_with_ids(response_template) # [('▁###', 835), ('▁Ass', 4007), ('istant', 22137), (':', 29901)]
|
||||
```
|
||||
|
||||
In this case, and due to lack of context in `response_template`, the same string ("### Assistant:") is tokenized differently:
|
||||
|
||||
- Text (with context): `[2277, 29937, 4007, 22137, 29901]`
|
||||
- `response_template` (without context): `[835, 4007, 22137, 29901]`
|
||||
|
||||
This will lead to an error when the `DataCollatorForCompletionOnlyLM` does not find the `response_template` in the dataset example text:
|
||||
|
||||
```
|
||||
RuntimeError: Could not find response key [835, 4007, 22137, 29901] in token IDs tensor([ 1, 835, ...])
|
||||
```
|
||||
|
||||
|
||||
To solve this, you can tokenize the `response_template` with the same context as in the dataset, truncate it as needed, and pass the `token_ids` directly to the `response_template` argument of the `DataCollatorForCompletionOnlyLM` class. For example:
|
||||
|
||||
```python
|
||||
response_template_with_context = "\n### Assistant:" # We added context here: "\n". This is enough for this tokenizer
|
||||
response_template_ids = tokenizer.encode(response_template_with_context, add_special_tokens=False)[2:] # Now we have it like in the dataset texts: `[2277, 29937, 4007, 22137, 29901]`
|
||||
|
||||
data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
|
||||
```
|
||||
|
||||
### Format your input prompts
|
||||
|
||||
For instruction fine-tuning, it is quite common to have two columns inside the dataset: one for the prompt & the other for the response.
|
||||
@ -142,7 +183,7 @@ trainer = SFTTrainer(
|
||||
|
||||
trainer.train()
|
||||
```
|
||||
To preperly format your input make sure to process all the examples by looping over them and returning a list of processed text. Check out a full example on how to use SFTTrainer on alpaca dataset [here](https://github.com/lvwerra/trl/pull/444#issue-1760952763)
|
||||
To properly format your input, make sure to process all the examples by looping over them and returning a list of processed texts. Check out a full example of how to use the SFTTrainer on the alpaca dataset [here](https://github.com/huggingface/trl/pull/444#issue-1760952763).
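For instance, here is a sketch of such a formatting function for a dataset with `instruction` and `response` columns (column names assumed for illustration; `model` and `dataset` are as in the snippets above):

```python
def formatting_prompts_func(examples):
    # Receives a batch of examples and must return a list of formatted strings
    output_texts = []
    for instruction, response in zip(examples["instruction"], examples["response"]):
        output_texts.append(f"### Question: {instruction}\n### Answer: {response}")
    return output_texts

trainer = SFTTrainer(
    model,
    train_dataset=dataset,
    formatting_func=formatting_prompts_func,  # applied to every batch of examples
)
```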
|
||||
|
||||
### Packing dataset ([`ConstantLengthDataset`])
|
||||
|
||||
@ -185,7 +226,7 @@ You can also customize the [`ConstantLengthDataset`] much more by directly passi
|
||||
|
||||
### Control over the pretrained model
|
||||
|
||||
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analoguous to
|
||||
You can directly pass the kwargs of the `from_pretrained()` method to the [`SFTTrainer`]. For example, if you want to load a model in a different precision, analogous to
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.bfloat16)
|
||||
|
197
docs/source/text_environments.md
Normal file
@ -0,0 +1,197 @@
|
||||
# Text Environments
|
||||
|
||||
Text environments provide a learning ground for language agents. They allow a language model to use tools to accomplish a task, such as using a Python interpreter to answer math questions or using a search index for trivia questions. Having access to tools allows language models to solve tasks that would be very hard for the model itself but can be trivial for the appropriate tool. A good example is arithmetic with large numbers, which becomes a simple copy-paste task once you have access to a calculator.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv.png">
|
||||
</div>
|
||||
|
||||
Let's dive into how text environments work and start with tools!
|
||||
|
||||
## Tools
|
||||
|
||||
One of the core building blocks of text environments are the tools that the model can use to solve tasks. In general, tools can be any Python function that takes a string as input and returns a string. The `TextEnvironment` offers two options for tools: either go with predefined tools from `transformers.Tool` or define your own function or class with a `__call__` method. Let's have a look at both!
|
||||
|
||||
### `transformers.Tool`
|
||||
|
||||
Text environments fully support tools of the class `transformers.Tool`. The advantage of building tools in that framework is that they can easily be shared.
|
||||
|
||||
```Python
|
||||
from transformers import load_tool
|
||||
|
||||
# simple calculator tool that runs +-/* operations
|
||||
calc_tool = load_tool("ybelkada/simple-calculator")
|
||||
|
||||
# python interpreter that executes program and returns outputs
|
||||
py_tool = load_tool("lvwerra/python-interpreter")
|
||||
|
||||
# wikipedia search index that returns best search match
|
||||
wiki_tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
|
||||
```
|
||||
|
||||
These tools are either loaded from the Hub or from a local folder. Using a tool is as simple as calling it with a text query:
|
||||
|
||||
```Python
|
||||
calc_tool("1/2")
|
||||
>>> "0.5"
|
||||
```
|
||||
|
||||
Note that both input and return values are strings to enable easy usage with a language model.
|
||||
|
||||
### Custom Tools
|
||||
|
||||
The following is an example of a tool that adds two integers:
|
||||
|
||||
```Python
|
||||
def add(text):
|
||||
int_1, int_2 = text.split("+")
|
||||
result = int(int_1) + int(int_2)
|
||||
return str(result)
|
||||
|
||||
print(add("1+1"))
|
||||
>>> "2"
|
||||
```
|
||||
|
||||
We looked at basic examples such as a calculator, but the principle holds for more complex tools as well, such as a web search tool where you input the query and get the search results in return. Now let's look at how the model can use the tools with the call syntax.
|
||||
|
||||
### Call syntax
|
||||
|
||||
In order to have a unified way for the model to call a tool we created a simple syntax that looks as follows:
|
||||
|
||||
```python
|
||||
"<request><TOOL_NAME>QUERY<call>TOOL_RESPONSE<response>"
|
||||
```
|
||||
|
||||
There are a few special tokens involved, so let's decompose it: First, the model can signal that it wants to use a tool by emitting the `<request>` token. After that, we want to know the name of the tool to call, which is done by enclosing the tool name in `<>` brackets. Once we know which tool to call, the tool query follows in free-text form. The `<call>` token signifies the end of the query and stops the model generation. At this point the model output is parsed and the query is sent to the tool. The environment appends the tool response to the string, followed by the `<response>` token to mark the end of the tool output.
|
||||
|
||||
Let's look at the concrete example of the calculator and assume its name is `Calculator` (more on how the name of a tool is inferred later):
|
||||
|
||||
```python
|
||||
"<request><Calculator>1/2<call>0.5<response>"
|
||||
```
|
||||
|
||||
Finally, the episode is ended and generation stops when the model generates `<submit>` which marks the interaction as completed.
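To make the syntax concrete, here is a small sketch (not TRL's internal parser) of how the tool name and query could be extracted from such a string:

```python
def parse_tool_call(text):
    # Split "<request><TOOL_NAME>QUERY<call>" into (tool_name, query)
    request = text.split("<request>")[-1]
    tool_name = request.split("<")[1].split(">")[0]
    query = request.split(">", 1)[1].split("<call>")[0]
    return tool_name, query

print(parse_tool_call("<request><Calculator>1/2<call>"))  # ('Calculator', '1/2')
```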
|
||||
|
||||
Now let's have a look at how we can create a new text environment!
|
||||
|
||||
## Create a `TextEnvironment`
|
||||
|
||||
|
||||
```python
|
||||
prompt = """\
|
||||
What is 13-3?
|
||||
<request><SimpleCalculatorTool>13-3<call>10.0<response>
|
||||
Result=10<submit>
|
||||
"""
|
||||
|
||||
def reward_fn(result, answer):
|
||||
"""Simplified reward function returning 1 if result matches answer and 0 otherwise."""
|
||||
result_parsed = result.split("=")[1].split("<")[0]
|
||||
return int(result_parsed==answer)
|
||||
|
||||
text_env = TextEnvironment(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
tools= {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
|
||||
reward_fn=reward_fn,
|
||||
prompt=prompt,
|
||||
max_turns=1,
|
||||
max_tool_response=100,
|
||||
generation_kwargs={"do_sample": "true"}
|
||||
)
|
||||
```
|
||||
|
||||
Let's decompose the settings:
|
||||
|
||||
| Argument | Description |
|
||||
|:-------------------|:----------------|
|
||||
| `model` | Language model to interact with the environment and generate requests. |
|
||||
| `tokenizer` | Tokenizer of language model handling tokenization of strings. |
|
||||
| `tools` | A `list` or `dict` of tools. If a list is passed, the name of each tool is inferred from its class name; otherwise the dictionary keys are used as tool names.|
|
||||
| `reward_fn` | A function that takes a string as input and returns a reward. It can have extra arguments that are passed to `.run()`, such as the ground truth.|
|
||||
| `prompt` | Prompt to prepend to every task. Usually a few examples to demonstrate to the model how to use the tools in a few-shot fashion. |
|
||||
| `max_turns` | Maximum number of interactions between model and tools before episode ends.|
|
||||
| `max_tool_response`| The tool response is truncated to this number to avoid running out of model context.|
|
||||
| `max_length` | The maximum number of tokens to allow in an episode. |
|
||||
| `generation_kwargs`| Generation settings used by the language model. |
|
||||
|
||||
You can customize the environment to your needs and add custom tools and settings. Let's see how you can use the environment to have the model interact with the available tools!
|
||||
|
||||
|
||||
## Run an Episode
|
||||
|
||||
To run a set of queries through the text environment one can simply use the `run` method.
|
||||
|
||||
```python
|
||||
queries = ["What is 1/2?"]
|
||||
answers = ["0.5"]
|
||||
|
||||
queries, responses, masks, rewards, histories = text_env.run(queries, answers=answers)
|
||||
```
|
||||
|
||||
This will execute the model/tool feedback loop for each query until either no tool is called anymore, the maximum number of turns is reached, or the maximum number of tokens in an episode is exceeded. The extra `kwargs` (e.g. `answers=answers` above) passed to `run` will be passed on to the reward function.
|
||||
|
||||
There are five objects that are returned by `run`:
|
||||
|
||||
- `queries`: a list of the tokenized queries
|
||||
- `responses`: all tokens that have been generated within the environment, including model and tool tokens
|
||||
- `masks`: mask that indicates which tokens have been generated by the model and which tokens are generated by the tool
|
||||
- `rewards`: a list of rewards for each query/response
|
||||
- `histories`: list of `TextHistory` objects, which are useful objects containing all the above and also the text equivalents
|
||||
|
||||
The masks are crucial for training, as we don't want to optimize tokens that the model has not generated, i.e. the tokens produced by the tools.
|
||||
|
||||
Next, we'll train a PPO step with the generated responses!
|
||||
|
||||
|
||||
### Train
|
||||
Training on episodes from the `TextEnvironment` is straightforward and simply requires forwarding all the returned variables except the `TextHistory` objects to the `step` method:
|
||||
|
||||
```python
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
```
|
||||
|
||||
## `TextHistory`
|
||||
|
||||
The `TextHistory` object stores the interactions between the model and the text environment. It stores the tokens and text generated in each turn, their source (model or system), as well as the rewards. Let's go through the class attributes and methods.
|
||||
|
||||
### Attributes
|
||||
|
||||
The following table summarises the available attributes of the `TextHistory` class:
|
||||
|
||||
| Attribute | Description |
|
||||
|:-------------------|:----------------|
|
||||
| `text` | The full string of the text generated in the text environment with both model and system generated text. |
|
||||
| `text_spans` | A list of tuples with the spans for each model or system generated text segment. |
|
||||
| `system_spans` | A list of boolean values indicating if the segment is model or system generated. |
|
||||
| `tokens` | All tokens generated in text environment with both model and system generated tokens. |
|
||||
| `token_spans` | Similar to `text_spans`, the `token_spans` indicate the boundaries of model and system generated tokens. |
|
||||
| `token_masks` | The token masks can be used to ignore system generated tokens by masking them. |
|
||||
| `completed` | Indicates if the interaction with the environment has completed. |
|
||||
| `truncated` | Indicates if the interaction with the environment has completed because max length was reached. |
|
||||
|
||||
With these attributes you can reconstruct every interaction of the model with the `TextEnvironment`. The `TextHistory` also lets you visualize the text history. Let's have a look!
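For instance, the following sketch (continuing from the `histories` returned by `text_env.run` above) prints every segment of the first episode together with its source:

```python
history = histories[0]

for (start, end), is_system in zip(history.text_spans, history.system_spans):
    source = "system/tool" if is_system else "model"
    print(f"[{source}] {history.text[start:end]}")
```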
|
||||
|
||||
### Visualization
|
||||
|
||||
When the model interacts inside the `TextEnvironment` it can be useful to visualize and separate which parts of the text outputs were generated by the model and which parts come from the system and tools. For that purpose there are the two methods [`TextHistory.show_text`] and [`TextHistory.show_tokens`]. They print the text and tokens respectively and highlight the various segments using the [`rich` library](https://github.com/Textualize/rich) (make sure to install it before using these methods).
|
||||
|
||||
You can see that the prompt is highlighted in gray, whereas system segments such as query and tool responses are highlighted in green. All segments generated by the model are highlighted in blue, and in addition to the pure text output the reward is displayed as additional text in plum. Here is an example of `show_text`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_text.png" width=600>
|
||||
</div>
|
||||
|
||||
Sometimes there can be tricky tokenization-related issues that are hidden when showing the decoded text. Thus `TextHistory` also offers an option to display the same highlighting on the tokens directly with `show_tokens`:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/textenv_show_tokens.png" width=800>
|
||||
</div>
|
||||
|
||||
Note that you can turn on the colour legend by passing `show_legend=True`.
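For example, continuing from the episode above (a short sketch; it assumes `rich` is installed):

```python
# Print the highlighted text of the first episode, with the colour legend enabled
histories[0].show_text(show_legend=True)
```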
|
||||
|
||||
## API Documentation
|
||||
|
||||
[[autodoc]] TextEnvironment
|
||||
|
||||
[[autodoc]] TextHistory
|
@ -24,6 +24,14 @@ We also support a `RewardTrainer` that can be used to train a reward model.
|
||||
|
||||
[[autodoc]] DPOTrainer
|
||||
|
||||
## DDPOConfig
|
||||
|
||||
[[autodoc]] DDPOConfig
|
||||
|
||||
## DDPOTrainer
|
||||
|
||||
[[autodoc]] DDPOTrainer
|
||||
|
||||
## set_seed
|
||||
|
||||
[[autodoc]] set_seed
|
||||
|
58
docs/source/use_model.md
Normal file
@ -0,0 +1,58 @@
|
||||
# Use model after training
|
||||
|
||||
Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference).
|
||||
|
||||
## Load and Generate
|
||||
|
||||
If you have fully fine-tuned a model, meaning without the use of PEFT, you can simply load it like any other language model in transformers. For example, the value head that was trained during the PPO training is no longer needed, and if you load the model with the original transformers class it will be ignored:
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
|
||||
device = "cpu" # or "cuda" if you have a GPU
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
|
||||
inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device)
|
||||
outputs = model.generate(inputs)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
Alternatively you can also use the pipeline:
|
||||
|
||||
```python
|
||||
from transformers import pipeline
|
||||
|
||||
model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
|
||||
pipe = pipeline("text-generation", model=model_name_or_path)
|
||||
print(pipe("This movie was really")[0]["generated_text"])
|
||||
```
|
||||
|
||||
## Use Adapters (PEFT)
|
||||
|
||||
```python
|
||||
from peft import PeftConfig, PeftModel
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub"
|
||||
adapter_model_name = "path/to/my/adapter"
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(base_model_name)
|
||||
model = PeftModel.from_pretrained(model, adapter_model_name)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
||||
```
|
||||
|
||||
You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger:
|
||||
|
||||
```python
|
||||
model = AutoModelForCausalLM.from_pretrained(base_model_name)
|
||||
model = PeftModel.from_pretrained(model, adapter_model_name)
|
||||
|
||||
model = model.merge_and_unload()
|
||||
model.save_pretrained("merged_adapters")
|
||||
```
|
||||
|
||||
Once you have the model loaded and have either merged the adapters or kept them separately on top, you can run generation as with a normal model, as outlined above.
|
@ -157,4 +157,4 @@ for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
```
|
||||
|
||||
For the rest of the details adn evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
|
||||
For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
|
@ -29,7 +29,7 @@ pip install trl
|
||||
pip install wandb
|
||||
```
|
||||
Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks.
|
||||
You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
|
||||
You can also replace it with your favorite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking).
|
||||
|
||||
## Accelerate Config
|
||||
For all the examples, you'll need to generate an `Accelerate` config with:
|
||||
@ -41,10 +41,10 @@ accelerate config # will prompt you to define the training configuration
|
||||
Then, it is encouraged to launch jobs with `accelerate launch`!
|
||||
|
||||
## Categories
|
||||
The examples are currently split over the following categories:
|
||||
The examples are currently split into the following categories:
|
||||
|
||||
**1: [ppo_trainer](https://github.com/lvwerra/trl/tree/main/examples/scripts/sentiment_tuning.py)**: Learn about different ways of using PPOTrainer
|
||||
**2: [sft_trainer](https://github.com/lvwerra/trl/tree/main/examples/scripts/sft_trainer.py)**: Learn about how to leverage `SFTTrainer` for supervised fine-tuning your pretrained language models easily.
|
||||
**3: [reward_modeling](https://github.com/lvwerra/trl/tree/main/examples/scripts/reward_trainer.py)**: Learn about how to use `RewardTrainer` to easily train your own reward model to use it for your RLHF pipeline.
|
||||
**4: [research_projects](https://github.com/lvwerra/trl/tree/main/examples/research_projects)**: Check out that folder to check the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
|
||||
**5: [notebooks](https://github.com/lvwerra/trl/tree/main/examples/notebooks)**: Check out that folder to check some applications of TRL features directly on a Jupyter notebook. This includes running sentiment tuning and sentiment control on a notebook, but also how to use "Best of N sampling" strategy using TRL.
|
||||
1. **[ppo_trainer](https://github.com/huggingface/trl/tree/main/examples/scripts/sentiment_tuning.py)**: Learn about different ways of using PPOTrainer
|
||||
1. **[sft_trainer](https://github.com/huggingface/trl/tree/main/examples/scripts/sft_trainer.py)**: Learn about how to leverage `SFTTrainer` for supervised fine-tuning your pretrained language models easily.
|
||||
1. **[reward_modeling](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_trainer.py)**: Learn about how to use `RewardTrainer` to easily train your own reward model to use it for your RLHF pipeline.
|
||||
1. **[research_projects](https://github.com/huggingface/trl/tree/main/examples/research_projects)**: Check out this folder to find the scripts used for some research projects that used TRL (LM de-toxification, Stack-Llama, etc.)
|
||||
1. **[notebooks](https://github.com/huggingface/trl/tree/main/examples/notebooks)**: Check out this folder for some applications of TRL features directly on a Jupyter notebook. This includes running sentiment tuning and sentiment control on a notebook and how to use the "Best of N sampling" strategy using TRL.
|
||||
|
@ -101,6 +101,7 @@ if __name__ == "__main__":
|
||||
]
|
||||
|
||||
model_ref = AutoModelForCausalLM.from_pretrained(script_args.model_name_or_path)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name_or_path)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
@ -2,6 +2,6 @@
|
||||
|
||||
This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications.
|
||||
|
||||
- [`best_of_n.ipynb`](https://github.com/lvwerra/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
|
||||
- [`gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
|
||||
- [`gpt2-control.ipynb`](https://github.com/lvwerra/trl/tree/main/examples/notebooks/gpt2-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control exampel on a jupyter notebook.
|
||||
- [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO.
|
||||
- [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook.
|
||||
- [`gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a Jupyter notebook.
|
@ -847,7 +847,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.16"
|
||||
"version": "3.9.12"
|
||||
},
|
||||
"vscode": {
|
||||
"interpreter": {
|
||||
|
@ -398,7 +398,7 @@
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Training progress\n",
|
||||
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://app.wandb.ai/lvwerra/trl-showcase/runs/1jtvxb1m/).\n",
|
||||
"If you are tracking the training progress with Weights&Biases you should see a plot similar to the one below. Check out the interactive sample report on wandb.ai: [link](https://app.wandb.ai/huggingface/trl-showcase/runs/1jtvxb1m/).\n",
|
||||
"\n",
|
||||
"<div style=\"text-align: center\">\n",
|
||||
"<img src='https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/gpt2_tuning_progress.png' width='800'>\n",
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Research projects that uses TRL
|
||||
# Research projects that use TRL
|
||||
|
||||
Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developpers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information!
|
||||
Welcome to the research projects folder! Here you can find the scripts used for some research projects that use TRL and are maintained by the developers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information!
|
||||
|
||||
- [De-detoxifying language models](https://github.com/lvwerra/trl/tree/main/examples/research_projects/toxicity)
|
||||
- [Stack-Llama](https://github.com/lvwerra/trl/tree/main/examples/research_projects/stack_llama)
|
||||
- [Detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity)
|
||||
- [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama)
|
||||
- [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2)
|
51
examples/research_projects/stack_llama_2/scripts/README.md
Normal file
@ -0,0 +1,51 @@
|
||||
# DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Install all the dependencies in the `requirements.txt`:
|
||||
|
||||
```
|
||||
$ pip install -U -r requirements.txt
|
||||
```
|
||||
|
||||
Since we will use `accelerate` for training, make sure to run:
|
||||
```
|
||||
$ accelerate config
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
There were two main steps to the DPO training process:
|
||||
1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se:
|
||||
- `accelerate launch examples/stack_llama_2/scripts/sft_llama2.py --output_dir="sft"`
|
||||
1. Run the DPO trainer using the model saved by the previous step:
|
||||
- `accelerate launch examples/stack_llama_2/scripts/dpo_llama2.py --model_name_or_path="sft/final_checkpoint" --output_dir="dpo"`
|
||||
|
||||
|
||||
## Merging the adaptors
|
||||
|
||||
To merge the adaptors into the base model we can use the `merge_peft_adapter.py` helper script that comes with TRL:
|
||||
|
||||
```
|
||||
python trl/examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2"
|
||||
```
|
||||
|
||||
which will also push the model to your HuggingFace hub account.
|
||||
|
||||
## Running the model
|
||||
|
||||
We can load the DPO-trained LoRA adaptors that were saved by the DPO training step via:
|
||||
|
||||
```py
|
||||
import torch
from peft import AutoPeftModelForCausalLM
|
||||
|
||||
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
"dpo/final_checkpoint",
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
load_in_4bit=True,
|
||||
)
|
||||
|
||||
model.generate(...)
|
||||
```
|
223
examples/research_projects/stack_llama_2/scripts/dpo_llama2.py
Normal file
@ -0,0 +1,223 @@
|
||||
# 0. imports
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import AutoPeftModelForCausalLM, LoraConfig
|
||||
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments
|
||||
|
||||
from trl import DPOTrainer
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The arguments for the DPO training script.
|
||||
"""
|
||||
|
||||
# data parameters
|
||||
beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
|
||||
|
||||
# training parameters
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default="../sft/results/final_checkpoint",
|
||||
metadata={"help": "the location of the SFT model name or path"},
|
||||
)
|
||||
learning_rate: Optional[float] = field(default=5e-4, metadata={"help": "optimizer learning rate"})
|
||||
lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
|
||||
warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
|
||||
weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
|
||||
optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})
|
||||
|
||||
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "train batch size per device"})
|
||||
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "eval batch size per device"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=4, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
gradient_checkpointing: Optional[bool] = field(
|
||||
default=True, metadata={"help": "whether to use gradient checkpointing"}
|
||||
)
|
||||
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
|
||||
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
|
||||
|
||||
max_prompt_length: Optional[int] = field(default=512, metadata={"help": "the maximum prompt length"})
|
||||
max_length: Optional[int] = field(default=1024, metadata={"help": "the maximum sequence length"})
|
||||
max_steps: Optional[int] = field(default=1000, metadata={"help": "max number of training steps"})
|
||||
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
|
||||
save_steps: Optional[int] = field(default=100, metadata={"help": "the saving frequency"})
|
||||
eval_steps: Optional[int] = field(default=100, metadata={"help": "the evaluation frequency"})
|
||||
|
||||
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
|
||||
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
|
||||
|
||||
# instrumentation
|
||||
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
|
||||
report_to: Optional[str] = field(
|
||||
default="wandb",
|
||||
metadata={
|
||||
"help": 'The list of integrations to report the results and logs to. Supported platforms are `"azure_ml"`,'
|
||||
'`"comet_ml"`, `"mlflow"`, `"neptune"`, `"tensorboard"`,`"clearml"` and `"wandb"`. '
|
||||
'Use `"all"` to report to all integrations installed, `"none"` for no integrations.'
|
||||
},
|
||||
)
|
||||
# debug argument for distributed training
|
||||
ignore_bias_buffers: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def get_stack_exchange_paired(
|
||||
data_dir: str = "data/rl",
|
||||
sanity_check: bool = False,
|
||||
cache_dir: str = None,
|
||||
num_proc=24,
|
||||
) -> Dataset:
|
||||
"""Load the stack-exchange-paired dataset from Hugging Face and convert it to the necessary format.
|
||||
|
||||
The dataset is converted to a dictionary with the following structure:
|
||||
{
|
||||
'prompt': List[str],
|
||||
'chosen': List[str],
|
||||
'rejected': List[str],
|
||||
}
|
||||
|
||||
Prompts are structured as follows:
|
||||
"Question: " + <prompt> + "\n\nAnswer: "
|
||||
"""
|
||||
dataset = load_dataset(
|
||||
"lvwerra/stack-exchange-paired",
|
||||
split="train",
|
||||
cache_dir=cache_dir,
|
||||
data_dir=data_dir,
|
||||
)
|
||||
original_columns = dataset.column_names
|
||||
|
||||
if sanity_check:
|
||||
dataset = dataset.select(range(min(len(dataset), 1000)))
|
||||
|
||||
def return_prompt_and_responses(samples) -> Dict[str, str]:
|
||||
return {
|
||||
"prompt": ["Question: " + question + "\n\nAnswer: " for question in samples["question"]],
|
||||
"chosen": samples["response_j"],
|
||||
"rejected": samples["response_k"],
|
||||
}
|
||||
|
||||
return dataset.map(
|
||||
return_prompt_and_responses,
|
||||
batched=True,
|
||||
num_proc=num_proc,
|
||||
remove_columns=original_columns,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoPeftModelForCausalLM.from_pretrained(
|
||||
script_args.model_name_or_path,
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
load_in_4bit=True,
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
if script_args.ignore_bias_buffers:
|
||||
# torch distributed hack
|
||||
model._ddp_params_and_buffers_to_ignore = [
|
||||
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
|
||||
]
|
||||
|
||||
model_ref = AutoPeftModelForCausalLM.from_pretrained(
|
||||
script_args.model_name_or_path,
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
load_in_4bit=True,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. Load the Stack-exchange paired dataset
|
||||
train_dataset = get_stack_exchange_paired(data_dir="data/rl", sanity_check=script_args.sanity_check)
|
||||
train_dataset = train_dataset.filter(
|
||||
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
|
||||
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
|
||||
)
|
||||
|
||||
# 3. Load evaluation dataset
|
||||
eval_dataset = get_stack_exchange_paired(data_dir="data/evaluation", sanity_check=True)
|
||||
eval_dataset = eval_dataset.filter(
|
||||
lambda x: len(x["prompt"]) + len(x["chosen"]) <= script_args.max_length
|
||||
and len(x["prompt"]) + len(x["rejected"]) <= script_args.max_length
|
||||
)
|
||||
|
||||
# 4. initialize training arguments:
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=script_args.per_device_train_batch_size,
|
||||
per_device_eval_batch_size=script_args.per_device_eval_batch_size,
|
||||
max_steps=script_args.max_steps,
|
||||
logging_steps=script_args.logging_steps,
|
||||
save_steps=script_args.save_steps,
|
||||
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
|
||||
gradient_checkpointing=script_args.gradient_checkpointing,
|
||||
learning_rate=script_args.learning_rate,
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=script_args.eval_steps,
|
||||
output_dir=script_args.output_dir,
|
||||
report_to=script_args.report_to,
|
||||
lr_scheduler_type=script_args.lr_scheduler_type,
|
||||
warmup_steps=script_args.warmup_steps,
|
||||
optim=script_args.optimizer_type,
|
||||
bf16=True,
|
||||
remove_unused_columns=False,
|
||||
run_name="dpo_llama2",
|
||||
)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=script_args.lora_r,
|
||||
lora_alpha=script_args.lora_alpha,
|
||||
lora_dropout=script_args.lora_dropout,
|
||||
target_modules=[
|
||||
"q_proj",
|
||||
"v_proj",
|
||||
"k_proj",
|
||||
"out_proj",
|
||||
"fc_in",
|
||||
"fc_out",
|
||||
"wte",
|
||||
],
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# 5. initialize the DPO trainer
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
beta=script_args.beta,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
peft_config=peft_config,
|
||||
max_prompt_length=script_args.max_prompt_length,
|
||||
max_length=script_args.max_length,
|
||||
)
|
||||
|
||||
# 6. train
|
||||
dpo_trainer.train()
|
||||
dpo_trainer.save_model(script_args.output_dir)
|
||||
|
||||
# 7. save
|
||||
output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
|
||||
dpo_trainer.model.save_pretrained(output_dir)
|
@ -0,0 +1,7 @@
|
||||
transformers
|
||||
trl
|
||||
peft
|
||||
accelerate
|
||||
datasets
|
||||
bitsandbytes
|
||||
wandb
|
216
examples/research_projects/stack_llama_2/scripts/sft_llama2.py
Normal file
@ -0,0 +1,216 @@
|
||||
# Fine-Tune Llama2-7b on SE paired dataset
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import AutoPeftModelForCausalLM, LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
|
||||
|
||||
from trl import SFTTrainer
|
||||
from trl.trainer import ConstantLengthDataset
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
|
||||
log_with: Optional[str] = field(default="wandb", metadata={"help": "use 'wandb' to log with wandb"})
|
||||
|
||||
dataset_name: Optional[str] = field(default="lvwerra/stack-exchange-paired", metadata={"help": "the dataset name"})
|
||||
subset: Optional[str] = field(default="data/finetune", metadata={"help": "the subset to use"})
|
||||
split: Optional[str] = field(default="train", metadata={"help": "the split to use"})
|
||||
size_valid_set: Optional[int] = field(default=4000, metadata={"help": "the size of the validation set"})
|
||||
streaming: Optional[bool] = field(default=True, metadata={"help": "whether to stream the dataset"})
|
||||
shuffle_buffer: Optional[int] = field(default=5000, metadata={"help": "the shuffle buffer size"})
|
||||
seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
|
||||
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
|
||||
|
||||
max_steps: Optional[int] = field(default=500, metadata={"help": "the maximum number of sgd steps"})
|
||||
logging_steps: Optional[int] = field(default=10, metadata={"help": "the logging frequency"})
|
||||
save_steps: Optional[int] = field(default=10, metadata={"help": "the saving frequency"})
|
||||
per_device_train_batch_size: Optional[int] = field(default=4, metadata={"help": "the per device train batch size"})
|
||||
per_device_eval_batch_size: Optional[int] = field(default=1, metadata={"help": "the per device eval batch size"})
|
||||
gradient_accumulation_steps: Optional[int] = field(default=2, metadata={"help": "the gradient accumulation steps"})
|
||||
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "whether to use gradient checkpointing"}
    )
    group_by_length: Optional[bool] = field(default=False, metadata={"help": "whether to group by length"})
    packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})

    lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
    lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
    lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})

    learning_rate: Optional[float] = field(default=1e-4, metadata={"help": "the learning rate"})
    lr_scheduler_type: Optional[str] = field(default="cosine", metadata={"help": "the lr scheduler type"})
    num_warmup_steps: Optional[int] = field(default=100, metadata={"help": "the number of warmup steps"})
    weight_decay: Optional[float] = field(default=0.05, metadata={"help": "the weight decay"})
    optimizer_type: Optional[str] = field(default="paged_adamw_32bit", metadata={"help": "the optimizer type"})

    output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
    log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]

if script_args.group_by_length and script_args.packing:
    raise ValueError("Cannot use both packing and group by length")


def chars_token_ratio(dataset, tokenizer, nb_examples=400):
    """
    Estimate the average number of characters per token in the dataset.
    """
    total_characters, total_tokens = 0, 0
    for _, example in tqdm(zip(range(nb_examples), iter(dataset)), total=nb_examples):
        text = prepare_sample_text(example)
        total_characters += len(text)
        if tokenizer.is_fast:
            total_tokens += len(tokenizer(text).tokens())
        else:
            total_tokens += len(tokenizer.tokenize(text))

    return total_characters / total_tokens


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


def prepare_sample_text(example):
    """Prepare the text from a sample of the dataset."""
    text = f"Question: {example['question']}\n\nAnswer: {example['response_j']}"
    return text


def create_datasets(tokenizer, args):
    dataset = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        use_auth_token=True,
        num_proc=args.num_workers if not args.streaming else None,
        streaming=args.streaming,
    )
    if args.streaming:
        print("Loading the dataset in streaming mode")
        valid_data = dataset.take(args.size_valid_set)
        train_data = dataset.skip(args.size_valid_set)
        train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None)
    else:
        dataset = dataset.train_test_split(test_size=0.005, seed=None)
        train_data = dataset["train"]
        valid_data = dataset["test"]
        print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")

    chars_per_token = chars_token_ratio(train_data, tokenizer)
    print(f"The character to token ratio of the dataset is: {chars_per_token:.2f}")

    train_dataset = ConstantLengthDataset(
        tokenizer,
        train_data,
        formatting_func=prepare_sample_text,
        infinite=True,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    valid_dataset = ConstantLengthDataset(
        tokenizer,
        valid_data,
        formatting_func=prepare_sample_text,
        infinite=False,
        seq_length=args.seq_length,
        chars_per_token=chars_per_token,
    )
    return train_dataset, valid_dataset


bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name,
    quantization_config=bnb_config,
    device_map={"": 0},
    trust_remote_code=True,
    use_auth_token=True,
)
base_model.config.use_cache = False

peft_config = LoraConfig(
    r=script_args.lora_r,
    lora_alpha=script_args.lora_alpha,
    lora_dropout=script_args.lora_dropout,
    target_modules=["q_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # Fix weird overflow issue with fp16 training


training_args = TrainingArguments(
    output_dir=script_args.output_dir,
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,
    per_device_eval_batch_size=script_args.per_device_eval_batch_size,
    learning_rate=script_args.learning_rate,
    logging_steps=script_args.logging_steps,
    max_steps=script_args.max_steps,
    report_to=script_args.log_with,
    save_steps=script_args.save_steps,
    group_by_length=script_args.group_by_length,
    lr_scheduler_type=script_args.lr_scheduler_type,
    warmup_steps=script_args.num_warmup_steps,
    optim=script_args.optimizer_type,
    bf16=True,
    remove_unused_columns=False,
    run_name="sft_llama2",
)

train_dataset, eval_dataset = create_datasets(tokenizer, script_args)

trainer = SFTTrainer(
    model=base_model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    packing=script_args.packing,
    max_seq_length=None,
    tokenizer=tokenizer,
    args=training_args,
)
trainer.train()
trainer.save_model(script_args.output_dir)

output_dir = os.path.join(script_args.output_dir, "final_checkpoint")
trainer.model.save_pretrained(output_dir)

# Free memory for merging weights
del base_model
torch.cuda.empty_cache()

model = AutoPeftModelForCausalLM.from_pretrained(output_dir, device_map="auto", torch_dtype=torch.bfloat16)
model = model.merge_and_unload()

output_merged_dir = os.path.join(script_args.output_dir, "final_merged_checkpoint")
model.save_pretrained(output_merged_dir, safe_serialization=True)
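A quick orientation note, not part of the diff: the `final_merged_checkpoint` written above is a plain `transformers` checkpoint, so it can be reloaded directly for inference. A minimal sketch follows; the output path assumes the default `--output_dir`, and the base model name is hypothetical (use the same `--model_name` as during training, since the tokenizer is not saved into the merged directory).

# Minimal sketch, assuming default paths; names below are illustrative only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

merged_dir = "./results/final_merged_checkpoint"  # assumes the default --output_dir
base_model_name = "meta-llama/Llama-2-7b-hf"  # hypothetical; reuse the training --model_name

model = AutoModelForCausalLM.from_pretrained(merged_dir, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

inputs = tokenizer("Question: What is TRL?\n\nAnswer:", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0], skip_special_tokens=True))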
examples/research_projects/tools/calculator.py (new file, 119 lines)
@@ -0,0 +1,119 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re

import numpy as np
import torch
from transformers import AutoTokenizer, load_tool

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment


def generate_data(n):
    """Generate random arithmetic tasks and answers."""
    tasks, answers = [], []
    for _ in range(n):
        a = np.random.randint(0, 50)
        b = np.random.randint(0, 50)
        op = np.random.choice(["-", "+", "*"])
        tasks.append(f"\n\nWhat is {a} {op} {b}?")
        if op == "-":
            answers.append(a - b)
        elif op == "+":
            answers.append(a + b)
        else:
            answers.append(a * b)
    return tasks, answers


def exact_match_reward(responses, answers=None):
    """Reward if generated response contains correct answer."""
    rewards = []
    pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>"  # generated by chatGPT
    for response, answer in zip(responses, answers):
        reward = 0.0
        predicted_number = None
        match_pattern = re.findall(pattern, response)
        if match_pattern:
            predicted_number = float(match_pattern[0])
        if predicted_number is not None:
            if np.abs(predicted_number - answer) < 0.01:
                reward += 1.0
        rewards.append(torch.tensor(reward))
    return rewards


# set up models
model_id = "gpt2"
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

# system prompt
prompt = """\
What is 13-3?

<request><SimpleCalculatorTool>13-3<call>10.0<response>

Result=10<submit>

What is 4*3?

<request><SimpleCalculatorTool>4*3<call>12.0<response>

Result=12<submit>"""

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": -1,
    "max_new_tokens": 32,
}

# trainer
ppo_config = PPOConfig(
    batch_size=256,
    learning_rate=1.41e-5,
    mini_batch_size=64,
    log_with="wandb",
)
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# text env
text_env = TextEnvironment(
    model,
    tokenizer,
    {"SimpleCalculatorTool": load_tool("ybelkada/simple-calculator")},
    exact_match_reward,
    prompt,
    generation_kwargs=generation_kwargs,
)

# main training loop
for step in range(100):
    tasks, answers = generate_data(ppo_config.batch_size)
    queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
    train_stats = ppo_trainer.step(queries, responses, rewards, masks)

    response_texts = [tokenizer.decode(response) for response in responses]
    query_texts = [tokenizer.decode(query) for query in queries]
    texts = {"query": [qt.split("<submit>")[-1].strip() for qt in query_texts], "response": response_texts}
    ppo_trainer.log_stats(train_stats, texts, rewards)
ppo_trainer.save_pretrained(model_id + "-calculator")
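For clarity (a toy sketch, not part of the diff): the reward above only fires when the model emits a `Result=<number><submit>` span whose number matches the answer, so a well-formed correct completion scores 1.0 and anything else scores 0.0.

# Toy check of the reward pattern used in calculator.py (illustrative only).
import re
import numpy as np

pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>"
responses = ["<request><SimpleCalculatorTool>13-3<call>10.0<response>Result=10<submit>", "I think it is 11."]
answers = [10, 10]

for response, answer in zip(responses, answers):
    match = re.findall(pattern, response)
    reward = 1.0 if match and np.abs(float(match[0]) - answer) < 0.01 else 0.0
    print(reward)  # prints 1.0, then 0.0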
examples/research_projects/tools/python_interpreter.py (new file, 194 lines)
@@ -0,0 +1,194 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, HfArgumentParser, load_tool
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
|
||||
|
||||
|
||||
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
|
||||
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
||||
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
||||
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=16, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
|
||||
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
|
||||
n_epochs: Optional[int] = field(default=32, metadata={"help": "max number of ppo epochs"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
|
||||
def exact_match_reward(responses, answers=None):
|
||||
"""Reward if generated response contains correct answer."""
|
||||
rewards = []
|
||||
pattern = r"Result\s*=\s*(-?\d+(?:\.\d+)?)\s*<submit>" # generated by chatGPT
|
||||
for response, answer in zip(responses, answers):
|
||||
reward = 0.0
|
||||
try:
|
||||
predicted_number = None
|
||||
match_pattern = re.findall(pattern, response)
|
||||
if match_pattern:
|
||||
predicted_number = float(match_pattern[0])
|
||||
if predicted_number is not None:
|
||||
if np.abs((predicted_number - float(answer))) < 0.1:
|
||||
reward += 1.0
|
||||
except: # noqa
|
||||
pass
|
||||
rewards.append(torch.tensor(reward))
|
||||
return rewards
|
||||
|
||||
|
||||
def evaluate(test_dataloader, text_env, ppo_trainer):
|
||||
test_rewards = []
|
||||
for test_batch in test_dataloader:
|
||||
_, _, _, rewards, _ = text_env.run(test_batch["query"], answers=test_batch["answer"])
|
||||
test_rewards.extend(rewards)
|
||||
test_rewards = ppo_trainer.accelerator.gather_for_metrics(
|
||||
torch.stack(test_rewards).to(ppo_trainer.accelerator.device)
|
||||
)
|
||||
return test_rewards.mean()
|
||||
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=["c_proj", "c_attn", "q_attn"],
|
||||
)
|
||||
|
||||
# set up models
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
args.model_name,
|
||||
use_auth_token=True,
|
||||
load_in_4bit=True,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
ds = load_dataset("gsm8k", "main", split="train")
|
||||
ds = ds.rename_columns({"question": "query"})
|
||||
ds = ds.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
|
||||
ds = ds.select(range(1, len(ds))) # skip the first sample which is used in prompt
|
||||
|
||||
ds_test = load_dataset("gsm8k", "main", split="test")
|
||||
ds_test = ds_test.rename_columns({"question": "query"})
|
||||
ds_test = ds_test.map(lambda x: {"answer": x["answer"].split("#### ")[1]})
|
||||
|
||||
test_dataloader = torch.utils.data.DataLoader(ds_test, batch_size=args.batch_size)
|
||||
|
||||
# prompt
|
||||
prompt = """\
|
||||
Example of using a Python API to solve math questions.
|
||||
|
||||
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
|
||||
|
||||
<request><PythonInterpreter>
|
||||
def solution():
|
||||
money_initial = 23
|
||||
bagels = 5
|
||||
bagel_cost = 3
|
||||
money_spent = bagels * bagel_cost
|
||||
money_left = money_initial - money_spent
|
||||
result = money_left
|
||||
return result
|
||||
print(solution())
|
||||
<call>72<response>
|
||||
|
||||
Result = 72 <submit>
|
||||
|
||||
Q: """
|
||||
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"eos_token_id": -1,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
}
|
||||
|
||||
# trainer
|
||||
ppo_config = PPOConfig(
|
||||
batch_size=args.batch_size,
|
||||
learning_rate=args.learning_rate,
|
||||
mini_batch_size=args.mini_batch_size,
|
||||
ppo_epochs=args.ppo_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
log_with="wandb",
|
||||
tracker_project_name="trl-gsm8k",
|
||||
remove_unused_columns=False,
|
||||
optimize_cuda_cache=True,
|
||||
)
|
||||
|
||||
ppo_trainer = PPOTrainer(config=ppo_config, model=model, tokenizer=tokenizer, dataset=ds)
|
||||
test_dataloader = ppo_trainer.accelerator.prepare(test_dataloader)
|
||||
|
||||
# text env
|
||||
text_env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
[load_tool("lvwerra/python-interpreter")],
|
||||
exact_match_reward,
|
||||
prompt,
|
||||
max_turns=2,
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
|
||||
# main training loop
|
||||
for epoch in range(args.n_epochs):
|
||||
for step, batch in enumerate(ppo_trainer.dataloader):
|
||||
if (step == 0) and (epoch % 4 == 0): # evaluate every 4 epochs
|
||||
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
|
||||
else:
|
||||
reward_mean_test = None
|
||||
|
||||
queries, responses, masks, rewards, histories = text_env.run(batch["query"], answers=batch["answer"])
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
|
||||
# logging
|
||||
if reward_mean_test is not None:
|
||||
train_stats["env/reward_mean_test"] = reward_mean_test
|
||||
texts = {
|
||||
"query": batch["query"],
|
||||
"response": [tokenizer.decode(response) for response in responses],
|
||||
"answer": batch["answer"],
|
||||
}
|
||||
ppo_trainer.log_stats(train_stats, texts, rewards)
|
||||
|
||||
reward_mean_test = evaluate(test_dataloader, text_env, ppo_trainer)
|
||||
ppo_trainer.save_pretrained(f"model/{args.model_name}-gsm8k")
examples/research_projects/tools/triviaqa.py (new file, 189 lines)
@@ -0,0 +1,189 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoTokenizer, HfArgumentParser, load_tool
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer, TextEnvironment
|
||||
|
||||
|
||||
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
model_name: Optional[str] = field(default="bigcode/starcoderbase", metadata={"help": "the model name"})
|
||||
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
|
||||
learning_rate: Optional[float] = field(default=1e-5, metadata={"help": "the learning rate"})
|
||||
mini_batch_size: Optional[int] = field(default=1, metadata={"help": "the PPO minibatch size"})
|
||||
batch_size: Optional[int] = field(default=32, metadata={"help": "the batch size"})
|
||||
gradient_accumulation_steps: Optional[int] = field(
|
||||
default=16, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
max_new_tokens: Optional[int] = field(default=256, metadata={"help": "max number of generated tokens per turn"})
|
||||
ppo_epochs: Optional[int] = field(default=1, metadata={"help": "max number of ppo epochs"})
|
||||
iterations: Optional[int] = field(default=1000, metadata={"help": "the number of iterations"})
|
||||
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
target_modules=["c_proj", "c_attn", "q_attn"],
|
||||
)
|
||||
|
||||
# set up models
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
args.model_name,
|
||||
use_auth_token=True,
|
||||
trust_remote_code=True,
|
||||
load_in_4bit=True,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_auth_token=True)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# system prompt
|
||||
prompt = """\
|
||||
Answer the following question:
|
||||
|
||||
Q: In which branch of the arts is Patricia Neary famous?
|
||||
A: Ballets
|
||||
A2: <request><Wiki>Patricia Neary<call>Patricia Neary (born October 27, 1942) is an American ballerina, choreographer and ballet director, who has been particularly active in Switzerland. She has also been a highly successful ambassador for the Balanchine Trust, bringing George Balanchine's ballets to 60 cities around the globe.<response>
|
||||
Result=Ballets<submit>
|
||||
|
||||
Q: Who won Super Bowl XX?
|
||||
A: Chicago Bears
|
||||
A2: <request><Wiki>Super Bowl XX<call>Super Bowl XX was an American football game between the National Football Conference (NFC) champion Chicago Bears and the American Football Conference (AFC) champion New England Patriots to decide the National Football League (NFL) champion for the 1985 season. The Bears defeated the Patriots by the score of 46–10, capturing their first NFL championship (and Chicago's first overall sports victory) since 1963, three years prior to the birth of the Super Bowl. Super Bowl XX was played on January 26, 1986 at the Louisiana Superdome in New Orleans.<response>
|
||||
Result=Chicago Bears<submit>
|
||||
|
||||
Q: """
|
||||
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"eos_token_id": -1,
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
}
|
||||
|
||||
# trainer
|
||||
config = PPOConfig(
|
||||
batch_size=args.batch_size,
|
||||
model_name=args.model_name,
|
||||
learning_rate=args.learning_rate,
|
||||
log_with=args.log_with,
|
||||
mini_batch_size=args.mini_batch_size,
|
||||
ppo_epochs=args.ppo_epochs,
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
seed=args.seed,
|
||||
optimize_cuda_cache=True,
|
||||
)
|
||||
ppo_trainer = PPOTrainer(config=config, model=model, tokenizer=tokenizer)
|
||||
dataset = load_dataset("trivia_qa", "rc", split="train")
|
||||
local_seed = args.seed + ppo_trainer.accelerator.process_index * 100003 # Prime
|
||||
dataset = dataset.shuffle(local_seed)
|
||||
|
||||
|
||||
def data_generator():
|
||||
for i in range(len(dataset)):
|
||||
yield dataset[i]["question"], [item for item in dataset[i]["answer"]["normalized_aliases"]]
|
||||
|
||||
|
||||
gen = data_generator()
|
||||
gen = iter(gen)
|
||||
|
||||
|
||||
def generate_data(n):
|
||||
tasks, answers = [], []
|
||||
for i in range(n):
|
||||
q, a = next(gen)
|
||||
tasks.append(q)
|
||||
answers.append(a)
|
||||
return tasks, answers
|
||||
|
||||
|
||||
def exact_match_reward(responses, answers=None):
|
||||
"""Reward if generated response contains correct answer."""
|
||||
rewards = []
|
||||
for response, answer in zip(responses, answers):
|
||||
reward = 0.0
|
||||
for a in answer:
|
||||
if a.lower() in response.lower():
|
||||
reward += 1.0
|
||||
break
|
||||
rewards.append(torch.tensor(reward))
|
||||
return rewards
|
||||
|
||||
|
||||
# text env
|
||||
tool = load_tool("vwxyzjn/pyserini-wikipedia-kilt-doc")
|
||||
# limit the amount of tokens
tool_fn = lambda x: tool(x).split("\n")[1][:600]  # noqa
|
||||
text_env = TextEnvironment(
|
||||
model,
|
||||
tokenizer,
|
||||
{"Wiki": tool_fn},
|
||||
exact_match_reward,
|
||||
prompt,
|
||||
generation_kwargs=generation_kwargs,
|
||||
max_tool_reponse=400,
|
||||
)
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
trainable_params = 0
|
||||
all_param = 0
|
||||
for _, param in model.named_parameters():
|
||||
all_param += param.numel()
|
||||
if param.requires_grad:
|
||||
trainable_params += param.numel()
|
||||
print(
|
||||
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
|
||||
)
|
||||
|
||||
|
||||
print_trainable_parameters(model)
|
||||
# main training loop
|
||||
for i in range(args.iterations):
|
||||
tasks, answers = generate_data(config.batch_size)
|
||||
queries, responses, masks, rewards, histories = text_env.run(tasks, answers=answers)
|
||||
train_stats = ppo_trainer.step(queries, responses, rewards, masks)
|
||||
response_texts = [tokenizer.decode(response) for response in responses]
|
||||
query_texts = [tokenizer.decode(query) for query in queries]
|
||||
texts = {
|
||||
"query": [qt.split("<submit>")[-1].strip() for qt in query_texts],
|
||||
"response": response_texts,
|
||||
"answer": [", ".join(item) for item in answers],
|
||||
}
|
||||
all_rewards = ppo_trainer.accelerator.gather(torch.tensor(rewards, device=ppo_trainer.accelerator.device))
|
||||
ppo_trainer.log_stats(train_stats, texts, [item for item in all_rewards])
|
||||
if i % 100 == 0:
|
||||
ppo_trainer.save_pretrained(f"models/{args.model_name}_{args.seed}_{i}_triviaqa")
|
@@ -1,8 +1,9 @@
# Maintained scripts

This folder shows multiple ways to use the objects from TRL such as `SFTTrainer`, `RewardTrainer` and `PPOTrainer` in different scenarios.
This folder shows multiple ways to use the objects from TRL such as `SFTTrainer`, `RewardTrainer`, `DDPOTrainer` and `PPOTrainer` in different scenarios.

- `sft_trainer.py`: This script shows how to use the SFTTrainer to fine-tune a model or adapters on a target dataset.
- `reward_trainer.py`: This script shows how to use the RewardTrainer to train a reward model on your own dataset.
- `sentiment_tuning.py`: This script shows how to use the PPOTrainer to fine-tune a sentiment analysis model using the IMDB dataset.
- `multi_adapter_rl.py`: This script shows how to use the PPOTrainer to train a single base model with multiple adapters. This script requires you to run the example script with the reward model training beforehand.
- `stable_diffusion_tuning_example.py`: This script shows how to use the DDPOTrainer to fine-tune a stable diffusion model using reinforcement learning.
examples/scripts/multi_adapter_rl_v2.py (new file, 150 lines)
@@ -0,0 +1,150 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import BitsAndBytesConfig, HfArgumentParser, LlamaTokenizer
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
from trl.core import LengthSampler
|
||||
|
||||
|
||||
input_min_text_length = 6
|
||||
input_max_text_length = 12
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Causal LM model we wish to fine-tune with PPO
|
||||
"""
|
||||
|
||||
model_name: Optional[str] = field(default="huggyllama/llama-7b", metadata={"help": "the model name"})
|
||||
dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
|
||||
rm_adapter: Optional[str] = field(
|
||||
default="trl-lib/llama-7b-hh-rm-adapter", metadata={"help": "the rm adapter name"}
|
||||
)
|
||||
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
|
||||
use_safetensors: Optional[bool] = field(default=False, metadata={"help": "Use safetensors"})
|
||||
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
|
||||
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
|
||||
use_score_norm: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
|
||||
)
|
||||
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
|
||||
def create_and_prepare_dataset(tokenizer):
|
||||
dataset = load_dataset(script_args.dataset_name, split="train[:1%]")
|
||||
|
||||
input_size = LengthSampler(input_min_text_length, input_max_text_length)
|
||||
|
||||
def tokenize(example):
|
||||
text_size = input_size()
|
||||
example["input_ids"] = tokenizer.encode(example["chosen"])[:text_size]
|
||||
example["query"] = tokenizer.decode(example["input_ids"])
|
||||
return example
|
||||
|
||||
dataset = dataset.map(tokenize, batched=False)
|
||||
dataset.set_format("torch")
|
||||
return dataset
|
||||
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
nf4_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16
|
||||
)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
script_args.model_name,
|
||||
device_map={"": 0},
|
||||
peft_config=lora_config,
|
||||
quantization_config=nf4_config,
|
||||
reward_adapter=script_args.rm_adapter,
|
||||
use_safetensors=script_args.use_safetensors,
|
||||
)
|
||||
tokenizer = LlamaTokenizer.from_pretrained(script_args.model_name)
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
dataset = create_and_prepare_dataset(tokenizer)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return dict((key, [d[key] for d in data]) for key in data[0])
|
||||
|
||||
|
||||
config = PPOConfig(
|
||||
model_name=script_args.model_name,
|
||||
log_with=script_args.log_with,
|
||||
learning_rate=1e-5,
|
||||
batch_size=8,
|
||||
mini_batch_size=2,
|
||||
gradient_accumulation_steps=2,
|
||||
optimize_cuda_cache=True,
|
||||
seed=script_args.seed,
|
||||
use_score_scaling=script_args.use_score_scaling,
|
||||
use_score_norm=script_args.use_score_norm,
|
||||
score_clip=script_args.score_clip,
|
||||
)
|
||||
|
||||
ppo_trainer = PPOTrainer(
|
||||
config,
|
||||
model,
|
||||
ref_model=None,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dataset,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
generation_kwargs = {
|
||||
"top_k": 0.0,
|
||||
"top_p": 0.9,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.pad_token_id,
|
||||
}
|
||||
|
||||
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
question_tensors = batch["input_ids"]
|
||||
|
||||
response_tensors = ppo_trainer.generate(
|
||||
question_tensors,
|
||||
return_prompt=False,
|
||||
**generation_kwargs,
|
||||
)
|
||||
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
|
||||
|
||||
# Compute reward score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(ppo_trainer.accelerator.device)
|
||||
raw_rewards = ppo_trainer.model.compute_reward_score(**inputs)
|
||||
rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))] # take last token
|
||||
|
||||
# Run PPO step
|
||||
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
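A note on the reward indexing above (an illustrative sketch, not part of the diff): `compute_reward_score` returns per-token scores, and the script keeps only the score of the last token, with the second channel treated as the reward logit (an assumption about the reward head's output layout). With a dummy tensor of the same shape, the indexing looks like this.

# Dummy tensor standing in for raw_rewards with shape (batch, seq_len, 2); values are random.
import torch

raw_rewards = torch.randn(8, 12, 2)
rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))]  # one scalar per sequence
print(len(rewards), rewards[0].shape)  # 8 torch.Size([])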
@ -45,7 +45,6 @@ class ScriptArguments:
|
||||
default=1, metadata={"help": "the number of gradient accumulation steps"}
|
||||
)
|
||||
early_stopping: Optional[bool] = field(default=False, metadata={"help": "whether to early stop"})
|
||||
target_kl: Optional[float] = field(default=6, metadata={"help": "kl target for early stopping"})
|
||||
use_peft: Optional[bool] = field(default=False, metadata={"help": "whether to use peft"})
|
||||
use_seq2seq: Optional[bool] = field(default=False, metadata={"help": "whether to use seq2seq models"})
|
||||
kl_penalty: Optional[str] = field(
|
||||
@ -56,6 +55,11 @@ class ScriptArguments:
|
||||
)
|
||||
target_kl: Optional[float] = field(default=0.1, metadata={"help": "kl target for early stopping"})
|
||||
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
|
||||
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
|
||||
use_score_norm: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
|
||||
)
|
||||
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
@ -72,8 +76,13 @@ config = PPOConfig(
|
||||
target_kl=script_args.target_kl,
|
||||
kl_penalty=script_args.kl_penalty,
|
||||
seed=script_args.seed,
|
||||
use_score_scaling=script_args.use_score_scaling,
|
||||
use_score_norm=script_args.use_score_norm,
|
||||
score_clip=script_args.score_clip,
|
||||
)
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
set_seed(config.seed)
|
||||
|
||||
# We then define the arguments to pass to the sentiment analysis pipeline.
|
||||
# We set `return_all_scores` to True to get the sentiment score for each token.
|
||||
@ -127,9 +136,6 @@ def collator(data):
|
||||
return dict((key, [d[key] for d in data]) for key in data[0])
|
||||
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
set_seed(config.seed)
|
||||
|
||||
# Now let's build the model, the reference model, and the tokenizer.
|
||||
if not script_args.use_peft:
|
||||
ref_model = trl_model_class.from_pretrained(config.model_name)
|
||||
|
@ -57,6 +57,12 @@ class ScriptArguments:
|
||||
use_auth_token: Optional[bool] = field(default=True, metadata={"help": "Use HF auth token to access the model"})
|
||||
num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
|
||||
max_steps: Optional[int] = field(default=-1, metadata={"help": "the number of training steps"})
|
||||
save_steps: Optional[int] = field(
|
||||
default=100, metadata={"help": "Number of updates steps before two checkpoint saves"}
|
||||
)
|
||||
save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
|
||||
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
|
||||
hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
@ -98,6 +104,11 @@ training_args = TrainingArguments(
|
||||
logging_steps=script_args.logging_steps,
|
||||
num_train_epochs=script_args.num_train_epochs,
|
||||
max_steps=script_args.max_steps,
|
||||
report_to=script_args.log_with,
|
||||
save_steps=script_args.save_steps,
|
||||
save_total_limit=script_args.save_total_limit,
|
||||
push_to_hub=script_args.push_to_hub,
|
||||
hub_model_id=script_args.hub_model_id,
|
||||
)
|
||||
|
||||
# Step 4: Define the LoraConfig
|
||||
|
examples/scripts/stable_diffusion_tuning.py (new file, 230 lines)
@@ -0,0 +1,230 @@
|
||||
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layers = nn.Sequential(
|
||||
nn.Linear(768, 1024),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(1024, 128),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(128, 64),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(64, 16),
|
||||
nn.Linear(16, 1),
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, embed):
|
||||
return self.layers(embed)
|
||||
|
||||
|
||||
class AestheticScorer(torch.nn.Module):
|
||||
"""
|
||||
This model attempts to predict the aesthetic score of an image. The aesthetic score
|
||||
is a numerical approximation of how much a specific image is liked by humans on average.
|
||||
This is from https://github.com/christophschuhmann/improved-aesthetic-predictor
|
||||
"""
|
||||
|
||||
def __init__(self, *, dtype, model_id, model_filename):
|
||||
super().__init__()
|
||||
self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
|
||||
self.mlp = MLP()
|
||||
try:
|
||||
cached_path = hf_hub_download(model_id, model_filename)
|
||||
except EntryNotFoundError:
|
||||
cached_path = os.path.join(model_id, model_filename)
|
||||
state_dict = torch.load(cached_path)
|
||||
self.mlp.load_state_dict(state_dict)
|
||||
self.dtype = dtype
|
||||
self.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, images):
|
||||
device = next(self.parameters()).device
|
||||
inputs = self.processor(images=images, return_tensors="pt")
|
||||
inputs = {k: v.to(self.dtype).to(device) for k, v in inputs.items()}
|
||||
embed = self.clip.get_image_features(**inputs)
|
||||
# normalize embedding
|
||||
embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
|
||||
return self.mlp(embed).squeeze(1)
|
||||
|
||||
|
||||
def aesthetic_scorer(hub_model_id, model_filename):
|
||||
scorer = AestheticScorer(
|
||||
model_id=hub_model_id,
|
||||
model_filename=model_filename,
|
||||
dtype=torch.float32,
|
||||
).cuda()
|
||||
|
||||
def _fn(images, prompts, metadata):
|
||||
images = (images * 255).round().clamp(0, 255).to(torch.uint8)
|
||||
scores = scorer(images)
|
||||
return scores, {}
|
||||
|
||||
return _fn
|
||||
|
||||
|
||||
# list of example prompts to feed stable diffusion
|
||||
animals = [
|
||||
"cat",
|
||||
"dog",
|
||||
"horse",
|
||||
"monkey",
|
||||
"rabbit",
|
||||
"zebra",
|
||||
"spider",
|
||||
"bird",
|
||||
"sheep",
|
||||
"deer",
|
||||
"cow",
|
||||
"goat",
|
||||
"lion",
|
||||
"frog",
|
||||
"chicken",
|
||||
"duck",
|
||||
"goose",
|
||||
"bee",
|
||||
"pig",
|
||||
"turkey",
|
||||
"fly",
|
||||
"llama",
|
||||
"camel",
|
||||
"bat",
|
||||
"gorilla",
|
||||
"hedgehog",
|
||||
"kangaroo",
|
||||
]
|
||||
|
||||
|
||||
def prompt_fn():
|
||||
return np.random.choice(animals), {}
|
||||
|
||||
|
||||
def image_outputs_logger(image_data, global_step, accelerate_logger):
|
||||
# For the sake of this example, we will only log the last batch of images
|
||||
# and associated data
|
||||
result = {}
|
||||
images, prompts, _, rewards, _ = image_data[-1]
|
||||
|
||||
for i, image in enumerate(images):
|
||||
prompt = prompts[i]
|
||||
reward = rewards[i].item()
|
||||
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0)
|
||||
|
||||
accelerate_logger.log_images(
|
||||
result,
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="DDPOConfig settings and pretrained model details.")
|
||||
|
||||
# DDPOConfig arguments
|
||||
parser.add_argument("--num_epochs", type=int, default=200)
|
||||
parser.add_argument("--train_gradient_accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--sample_num_steps", type=int, default=50)
|
||||
parser.add_argument("--sample_batch_size", type=int, default=6)
|
||||
parser.add_argument("--train_batch_size", type=int, default=3)
|
||||
parser.add_argument("--sample_num_batches_per_epoch", type=int, default=4)
|
||||
parser.add_argument("--per_prompt_stat_tracking", action="store_true", default=True)
|
||||
parser.add_argument("--per_prompt_stat_tracking_buffer_size", type=int, default=32)
|
||||
parser.add_argument("--tracker_project_name", default="stable_diffusion_training")
|
||||
parser.add_argument("--log_with", default="wandb")
|
||||
|
||||
parser.add_argument("--logging_dir", default="./logs")
|
||||
parser.add_argument("--automatic_checkpoint_naming", action="store_true", default=True)
|
||||
parser.add_argument("--total_limit", type=int, default=5)
|
||||
parser.add_argument("--project_dir", default="./save")
|
||||
|
||||
parser.add_argument("--pretrained_model", default="runwayml/stable-diffusion-v1-5")
|
||||
parser.add_argument("--pretrained_revision", default="main")
|
||||
parser.add_argument("--hf_user_access_token", required=True)
|
||||
parser.add_argument(
|
||||
"--hf_hub_model_id",
|
||||
help="HuggingFace repo to save model weights to",
|
||||
default="ddpo-finetuned-stable-diffusion",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf_hub_aesthetic_model_id",
|
||||
help="HuggingFace model ID for aesthetic scorer model weights",
|
||||
default="trl-lib/ddpo-aesthetic-predictor",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hf_hub_aesthetic_model_filename",
|
||||
default="aesthetic-model.pth",
|
||||
help="HuggingFace model filename for aesthetic scorer model weights",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
|
||||
project_kwargs = {
|
||||
"logging_dir": args.logging_dir,
|
||||
"automatic_checkpoint_naming": args.automatic_checkpoint_naming,
|
||||
"total_limit": args.total_limit,
|
||||
"project_dir": args.project_dir,
|
||||
}
|
||||
|
||||
config = DDPOConfig(
|
||||
num_epochs=args.num_epochs,
|
||||
train_gradient_accumulation_steps=args.train_gradient_accumulation_steps,
|
||||
sample_num_steps=args.sample_num_steps,
|
||||
sample_batch_size=args.sample_batch_size,
|
||||
train_batch_size=args.train_batch_size,
|
||||
sample_num_batches_per_epoch=args.sample_num_batches_per_epoch,
|
||||
per_prompt_stat_tracking=args.per_prompt_stat_tracking,
|
||||
per_prompt_stat_tracking_buffer_size=args.per_prompt_stat_tracking_buffer_size,
|
||||
tracker_project_name=args.tracker_project_name,
|
||||
log_with=args.log_with,
|
||||
project_kwargs=project_kwargs,
|
||||
)
|
||||
|
||||
pipeline = DefaultDDPOStableDiffusionPipeline(
|
||||
args.pretrained_model, pretrained_model_revision=args.pretrained_revision, use_lora=True
|
||||
)
|
||||
|
||||
trainer = DDPOTrainer(
|
||||
config,
|
||||
aesthetic_scorer(args.hf_hub_aesthetic_model_id, args.hf_hub_aesthetic_model_filename),
|
||||
prompt_fn,
|
||||
pipeline,
|
||||
image_samples_hook=image_outputs_logger,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
trainer.push_to_hub(args.hf_hub_model_id, token=args.hf_user_access_token)
|
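A small sketch (not part of the diff) of the embedding normalization the aesthetic scorer performs before its MLP: dividing by the vector norm puts every CLIP image embedding on the unit sphere, so the MLP sees inputs of comparable scale.

# Unit-normalization as done in AestheticScorer.__call__ above; the input tensor is random stand-in data.
import torch

embed = torch.randn(4, 768)  # pretend CLIP image features
embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
print(torch.linalg.vector_norm(embed, dim=-1))  # all ones (up to float error)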
@@ -3,4 +3,4 @@ torch>=1.4.0
tqdm
transformers
accelerate
peft>=0.3.0
@@ -30,7 +30,7 @@ LABELS_TO_EXEMPT = [

def main():
    g = Github(os.environ["GITHUB_TOKEN"])
    repo = g.get_repo("lvwerra/trl")
    repo = g.get_repo("huggingface/trl")
    open_issues = repo.get_issues(state="open")

    for issue in open_issues:
setup.py (9 changed lines)
@@ -57,7 +57,7 @@ To create the package for pypi.
from setuptools import find_packages, setup


__version__ = "0.5.0.dev0"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
__version__ = "0.7.0"  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)

REQUIRED_PKGS = [
    "torch>=1.4.0",
@@ -67,9 +67,10 @@ REQUIRED_PKGS = [
    "datasets",
]
EXTRAS = {
    "test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "peft"],
    "test": ["parameterized", "pytest", "pytest-xdist", "accelerate", "peft", "diffusers>=0.18.0"],
    "peft": ["peft>=0.2.0"],
    "dev": ["parameterized", "pytest", "pytest-xdist", "pre-commit", "peft>=0.2.0"],
    "diffusers": ["diffusers>=0.18.0"],
    "dev": ["parameterized", "pytest", "pytest-xdist", "pre-commit", "peft>=0.2.0", "diffusers>=0.18.0"],
}

setup(
@@ -88,7 +89,7 @@ setup(
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
    ],
    url="https://github.com/lvwerra/trl",
    url="https://github.com/huggingface/trl",
    packages=find_packages(),
    include_package_data=True,
    install_requires=REQUIRED_PKGS,
tests/test_data_collator_completion_only.py (new file, 66 lines)
@@ -0,0 +1,66 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from trl import DataCollatorForCompletionOnlyLM
|
||||
|
||||
|
||||
class DataCollatorForCompletionOnlyLMTester(unittest.TestCase):
|
||||
def test_data_collator_finds_response_template_llama2_tokenizer(self):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
|
||||
self.instruction = """### System: You are a helpful assistant.
|
||||
|
||||
### User: How much is 2+2?
|
||||
|
||||
### Assistant: 2+2 equals 4"""
|
||||
self.response_template = "\n### Assistant:"
|
||||
|
||||
# GPT2Tokenizer: [198, 21017, 15286, 25] -> [15286, 25]
|
||||
# Llama2Tokenizer: [29871, 13, 2277, 29937, 4007, 22137, 29901] -> [2277, 29937, 4007, 22137, 29901]
|
||||
self.tokenized_response_w_context = self.tokenizer.encode(self.response_template, add_special_tokens=False)[2:]
|
||||
|
||||
# Plain check on string
|
||||
self.assertIn(self.response_template, self.instruction)
|
||||
self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)
|
||||
|
||||
# Test the fix for #598
|
||||
# Pass already tokenized (w context) and truncated response_template so token_ids are like in the instruction + response
|
||||
self.collator = DataCollatorForCompletionOnlyLM(self.tokenized_response_w_context, tokenizer=self.tokenizer)
|
||||
self.collator.torch_call([self.tokenized_instruction])
|
||||
|
||||
def test_data_collator_handling_of_long_sequences(self):
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/dummy-GPT2-correct-vocab")
|
||||
self.instruction = """### System: You are a helpful assistant.
|
||||
|
||||
### User: How much is 2+2? I'm asking because I'm not sure. And I'm not sure because I'm not good at math.
|
||||
"""
|
||||
self.response_template = "\n### Assistant:"
|
||||
# check DataCollatorForCompletionOnlyLM using response template only
|
||||
self.tokenized_instruction = self.tokenizer.encode(self.instruction, add_special_tokens=False)
|
||||
self.collator = DataCollatorForCompletionOnlyLM(self.response_template, tokenizer=self.tokenizer)
|
||||
encoded_instance = self.collator.torch_call([self.tokenized_instruction])
|
||||
result = torch.all(encoded_instance["labels"] == -100)
|
||||
self.assertTrue(result, "Not all values in the tensor are -100.")
|
||||
|
||||
# check DataCollatorForCompletionOnlyLM using response template and instruction template
|
||||
self.instruction_template = "\n### User:"
|
||||
self.collator = DataCollatorForCompletionOnlyLM(
|
||||
self.response_template, self.instruction_template, tokenizer=self.tokenizer
|
||||
)
|
||||
encoded_instance = self.collator.torch_call([self.tokenized_instruction])
|
||||
result = torch.all(encoded_instance["labels"] == -100)
|
||||
self.assertTrue(result, "Not all values in the tensor are -100.")
tests/test_ddpo_trainer.py (new file, 92 lines)
@@ -0,0 +1,92 @@
|
||||
# Copyright 2023 metric-space, The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline
|
||||
|
||||
|
||||
def scorer_function(images, prompts, metadata):
|
||||
return torch.randn(1) * 3.0, {}
|
||||
|
||||
|
||||
def prompt_function():
|
||||
return ("cabbages", {})
|
||||
|
||||
|
||||
class DDPOTrainerTester(unittest.TestCase):
|
||||
"""
|
||||
Test the DDPOTrainer class.
|
||||
"""
|
||||
|
||||
def setUp(self):
|
||||
self.ddpo_config = DDPOConfig(
|
||||
num_epochs=2,
|
||||
train_gradient_accumulation_steps=1,
|
||||
per_prompt_stat_tracking_buffer_size=32,
|
||||
sample_num_batches_per_epoch=2,
|
||||
sample_batch_size=2,
|
||||
mixed_precision=None,
|
||||
save_freq=1000000,
|
||||
)
|
||||
pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch"
|
||||
pretrained_revision = "main"
|
||||
|
||||
pipeline = DefaultDDPOStableDiffusionPipeline(
|
||||
pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False
|
||||
)
|
||||
|
||||
self.trainer = DDPOTrainer(self.ddpo_config, scorer_function, prompt_function, pipeline)
|
||||
|
||||
return super().setUp()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
gc.collect()
|
||||
|
||||
def test_loss(self):
|
||||
advantage = torch.tensor([-1.0])
|
||||
clip_range = 0.0001
|
||||
ratio = torch.tensor([1.0])
|
||||
loss = self.trainer.loss(advantage, clip_range, ratio)
|
||||
self.assertEqual(loss.item(), 1.0)
|
||||
|
||||
def test_generate_samples(self):
|
||||
samples, output_pairs = self.trainer._generate_samples(1, 2)
|
||||
self.assertEqual(len(samples), 1)
|
||||
self.assertEqual(len(output_pairs), 1)
|
||||
self.assertEqual(len(output_pairs[0][0]), 2)
|
||||
|
||||
def test_calculate_loss(self):
|
||||
samples, _ = self.trainer._generate_samples(1, 2)
|
||||
sample = samples[0]
|
||||
|
||||
latents = sample["latents"][0, 0].unsqueeze(0)
|
||||
next_latents = sample["next_latents"][0, 0].unsqueeze(0)
|
||||
log_probs = sample["log_probs"][0, 0].unsqueeze(0)
|
||||
timesteps = sample["timesteps"][0, 0].unsqueeze(0)
|
||||
prompt_embeds = sample["prompt_embeds"]
|
||||
advantage = torch.tensor([1.0], device=prompt_embeds.device)
|
||||
|
||||
self.assertEqual(latents.shape, (1, 4, 64, 64))
|
||||
self.assertEqual(next_latents.shape, (1, 4, 64, 64))
|
||||
self.assertEqual(log_probs.shape, (1,))
|
||||
self.assertEqual(timesteps.shape, (1,))
|
||||
self.assertEqual(prompt_embeds.shape, (2, 77, 32))
|
||||
loss, approx_kl, clipfrac = self.trainer.calculate_loss(
|
||||
latents, timesteps, next_latents, log_probs, advantage, prompt_embeds
|
||||
)
|
||||
|
||||
self.assertTrue(torch.isfinite(loss.cpu()))
|
@ -16,10 +16,13 @@ import unittest
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from pytest import mark
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
||||
|
||||
from trl import DPOTrainer
|
||||
|
||||
from .testing_utils import require_peft
|
||||
|
||||
|
||||
class DPOTrainerTester(unittest.TestCase):
|
||||
@classmethod
|
||||
@ -30,6 +33,40 @@ class DPOTrainerTester(unittest.TestCase):
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
|
||||
cls.tokenizer.pad_token = cls.tokenizer.eos_token
|
||||
|
||||
def _init_dummy_dataset(self):
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Python",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"Java",
|
||||
],
|
||||
}
|
||||
# fmt: on
|
||||
return Dataset.from_dict(dummy_dataset_dict)
|
||||
|
||||
def test_dpo_trainer(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
@ -42,38 +79,7 @@ class DPOTrainerTester(unittest.TestCase):
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Python",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"Java",
|
||||
],
|
||||
}
|
||||
# fmt: on
|
||||
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=self.model,
|
||||
@ -97,3 +103,91 @@ class DPOTrainerTester(unittest.TestCase):
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
|
||||
def test_dpo_trainer_without_providing_ref_model(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=self.model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
def test_dpo_trainer_without_providing_ref_model_with_lora(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=self.model,
|
||||
ref_model=None,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
if "lora" in n:
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
273 tests/test_environments.py Normal file
@ -0,0 +1,273 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, TextEnvironment, TextHistory
|
||||
|
||||
|
||||
class DummyTool:
|
||||
def __call__(self, text):
|
||||
return text
|
||||
|
||||
|
||||
def dummy_generate(histories):
|
||||
for i in range(len(histories)):
|
||||
histories[i].append_segment("<request><DummyTool>test<call>", torch.tensor([1, 2, 3]), system=False)
|
||||
return histories
|
||||
|
||||
|
||||
class TextHistoryTest(unittest.TestCase):
|
||||
def test_text_history_init(self):
|
||||
text = "Hello there!"
|
||||
tokens = torch.tensor([1, 2, 3])
|
||||
|
||||
history = TextHistory(text, tokens)
|
||||
self.assertEqual(history.text, text)
|
||||
self.assertTrue(torch.equal(history.tokens, tokens))
|
||||
self.assertTrue(torch.equal(history.token_masks, torch.zeros_like(tokens)))
|
||||
|
||||
history = TextHistory(text, tokens, system=False)
|
||||
self.assertTrue(torch.equal(history.token_masks, torch.ones_like(tokens)))
|
||||
|
||||
def test_text_history_append_segment(self):
|
||||
text = "Hello there!"
|
||||
tokens = torch.tensor([1, 2, 3])
|
||||
|
||||
history = TextHistory(text, tokens)
|
||||
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False)
|
||||
self.assertEqual(history.text, text + "General Kenobi!")
|
||||
self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6])))
|
||||
self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1])))
|
||||
|
||||
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
|
||||
self.assertEqual(history.text, text + "General Kenobi!" + "You are a bold one!")
|
||||
self.assertTrue(torch.equal(history.tokens, torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9])))
|
||||
self.assertTrue(torch.equal(history.token_masks, torch.tensor([0, 0, 0, 1, 1, 1, 0, 0, 0])))
|
||||
|
||||
def test_text_history_complete(self):
|
||||
text = "Hello there!"
|
||||
tokens = torch.tensor([1, 2, 3])
|
||||
history = TextHistory(text, tokens)
|
||||
history.complete()
|
||||
self.assertTrue(history.completed)
|
||||
self.assertFalse(history.truncated)
|
||||
|
||||
history.complete(truncated=True)
|
||||
self.assertTrue(history.completed)
|
||||
self.assertTrue(history.truncated)
|
||||
|
||||
def test_text_history_last_segment(self):
|
||||
text = "Hello there!"
|
||||
tokens = torch.tensor([1, 2, 3])
|
||||
history = TextHistory(text, tokens)
|
||||
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]))
|
||||
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]))
|
||||
self.assertEqual(history.last_text_segment, "You are a bold one!")
|
||||
|
||||
def test_text_history_split_query_response(self):
|
||||
text = "Hello there!"
|
||||
tokens = torch.tensor([1, 2, 3])
|
||||
history = TextHistory(text, tokens)
|
||||
history.append_segment("General Kenobi!", torch.tensor([4, 5, 6]), system=False)
|
||||
history.append_segment("You are a bold one!", torch.tensor([7, 8, 9]), system=True)
|
||||
query, response, mask = history.split_query_response_tokens()
|
||||
|
||||
self.assertTrue(torch.equal(query, torch.tensor([1, 2, 3])))
|
||||
self.assertTrue(torch.equal(response, torch.tensor([4, 5, 6, 7, 8, 9])))
|
||||
self.assertTrue(torch.equal(mask, torch.tensor([1, 1, 1, 0, 0, 0])))
|
||||
|
||||
|
||||
class TextEnvironmentTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# model_id
|
||||
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
|
||||
|
||||
# get models and tokenizer
|
||||
cls.gpt2_model = AutoModelForCausalLMWithValueHead.from_pretrained(cls.model_id)
|
||||
cls.gpt2_tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
|
||||
cls.gpt2_tokenizer.pad_token = cls.gpt2_tokenizer.eos_token
|
||||
|
||||
def test_text_environment_setup(self):
|
||||
env = TextEnvironment(
|
||||
self.gpt2_model,
|
||||
self.gpt2_tokenizer,
|
||||
tools=[DummyTool()],
|
||||
reward_fn=lambda x: torch.tensor(1),
|
||||
prompt="I am a prompt!\n",
|
||||
)
|
||||
self.assertEqual(env.prompt, "I am a prompt!\n")
|
||||
self.assertEqual(list(env.tools.keys()), ["DummyTool"])
|
||||
self.assertTrue(isinstance(env.tools["DummyTool"], DummyTool))
|
||||
self.assertEqual(env.reward_fn("Hello there!"), 1)
|
||||
|
||||
def test_text_environment_generate(self):
|
||||
generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": self.gpt2_tokenizer.eos_token_id}
|
||||
env = TextEnvironment(
|
||||
self.gpt2_model,
|
||||
self.gpt2_tokenizer,
|
||||
tools=[DummyTool()],
|
||||
reward_fn=lambda x: torch.tensor(1),
|
||||
prompt="I am a prompt!\n",
|
||||
generation_kwargs=generation_kwargs,
|
||||
)
|
||||
|
||||
input_texts = ["this is a test", "this is another, longer test"]
|
||||
|
||||
model_inputs = [self.gpt2_tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]
|
||||
|
||||
generations_batched = env._generate_batched(model_inputs, batch_size=2)
|
||||
generations_batched = self.gpt2_tokenizer.batch_decode(generations_batched)
|
||||
|
||||
generations_single = [env._generate_batched([inputs], batch_size=1)[0] for inputs in model_inputs]
|
||||
generations_single = self.gpt2_tokenizer.batch_decode(generations_single)
|
||||
|
||||
self.assertEqual(generations_single, generations_batched)
|
||||
|
||||
def test_text_environment_tool_call_parsing(self):
|
||||
string_valid = "Something something <request><Tool1>Hello there!<call>"
|
||||
string_invalid_request = "Something something <Tool1>Hello there!<call>"
|
||||
string_invalid_call = "Something something <request><Tool1>Hello there!"
|
||||
string_invalid_tool = "Something something <request>|Tool2|Hello there!<call>"
|
||||
string_invalid_random = "<>abcdefghijklm<>nopqrstuvwxyz<>"
|
||||
|
||||
env = TextEnvironment(
|
||||
self.gpt2_model,
|
||||
self.gpt2_tokenizer,
|
||||
tools=[DummyTool()],
|
||||
reward_fn=lambda x: torch.tensor(1),
|
||||
prompt="I am a prompt!\n",
|
||||
)
|
||||
tool, response = env.parse_tool_call(string_valid)
|
||||
self.assertEqual(tool, "Tool1")
|
||||
self.assertEqual(response, "Hello there!")
|
||||
|
||||
tool, response = env.parse_tool_call(string_invalid_request)
|
||||
self.assertEqual(tool, None)
|
||||
self.assertEqual(response, None)
|
||||
|
||||
tool, response = env.parse_tool_call(string_invalid_call)
|
||||
self.assertEqual(tool, None)
|
||||
self.assertEqual(response, None)
|
||||
|
||||
tool, response = env.parse_tool_call(string_invalid_tool)
|
||||
self.assertEqual(tool, None)
|
||||
self.assertEqual(response, None)
|
||||
|
||||
tool, response = env.parse_tool_call(string_invalid_random)
|
||||
self.assertEqual(tool, None)
|
||||
self.assertEqual(response, None)
|
||||
|
||||
def test_text_environment_tool_truncation(self):
|
||||
env = TextEnvironment(
|
||||
self.gpt2_model,
|
||||
self.gpt2_tokenizer,
|
||||
tools={"dummy": lambda x: "a" * 1000},
|
||||
reward_fn=lambda x: torch.tensor(1),
|
||||
prompt="I am a prompt!\n",
|
||||
)
|
||||
|
||||
env.max_tool_response = 100
|
||||
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
|
||||
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 100)
|
||||
|
||||
env.max_tool_response = 500
|
||||
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
|
||||
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 500)
|
||||
|
||||
env.max_tool_response = 1001
|
||||
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
|
||||
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000)
|
||||
|
||||
env.max_tool_response = 2000
|
||||
history = env.step(TextHistory("<request><dummy>Hello there!<call>", torch.tensor([1, 2, 3])))
|
||||
self.assertEqual(len(history.last_text_segment) - len(env.response_token), 1000)
|
||||
|
||||
@patch.object(TextEnvironment, "generate", side_effect=dummy_generate)
|
||||
def test_text_environment_max_calls(self, mock_generate):
|
||||
env = TextEnvironment(
|
||||
self.gpt2_model,
|
||||
self.gpt2_tokenizer,
|
||||
tools={"DummyTool": DummyTool()},
|
||||
reward_fn=lambda x: [torch.tensor(1) for _ in x],
|
||||
prompt="I am a prompt!\n",
|
||||
)
|
||||
|
||||
env.max_turns = 1
|
||||
_, _, _, _, histories = env.run(["test"])
|
||||
self.assertEqual(
|
||||
histories[0].text, "I am a prompt!\n" + "test" + 1 * "<request><DummyTool>test<call>test<response>"
|
||||
)
|
||||
|
||||
env.max_turns = 2
|
||||
_, _, _, _, histories = env.run(["test"])
|
||||
self.assertEqual(
|
||||
histories[0].text, "I am a prompt!\n" + "test" + 2 * "<request><DummyTool>test<call>test<response>"
|
||||
)
|
||||
|
||||
env.max_turns = 4
|
||||
_, _, _, _, histories = env.run(["test"])
|
||||
self.assertEqual(
|
||||
histories[0].text, "I am a prompt!\n" + "test" + 4 * "<request><DummyTool>test<call>test<response>"
|
||||
)
|
||||
|
||||
def test_text_environment_compute_rewards(self):
|
||||
env = TextEnvironment(
|
||||
self.gpt2_model,
|
||||
self.gpt2_tokenizer,
|
||||
tools={"DummyTool": DummyTool()},
|
||||
reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)],
|
||||
prompt="I am a prompt!\n",
|
||||
)
|
||||
|
||||
histories = [TextHistory("<request><DummyTool>test<call>", torch.tensor([1, 2, 3])) for _ in range(8)]
|
||||
histories = env.compute_reward(histories)
|
||||
|
||||
for i in range(8):
|
||||
self.assertEqual(histories[i].reward, i)
|
||||
|
||||
@patch.object(TextEnvironment, "generate", side_effect=dummy_generate)
|
||||
def test_text_environment_run(self, mock_generate):
|
||||
env = TextEnvironment(
|
||||
self.gpt2_model,
|
||||
self.gpt2_tokenizer,
|
||||
tools={"DummyTool": DummyTool()},
|
||||
reward_fn=lambda x: [torch.tensor(i) for i, _ in enumerate(x)],
|
||||
prompt="I am a prompt!\n",
|
||||
max_turns=2,
|
||||
)
|
||||
task_1 = "Hello there!"
|
||||
task_2 = "Hello there! General Kenobi!"
|
||||
|
||||
query, response, response_mask, reward, histories = env.run([task_1, task_2])
|
||||
self.assertEqual(len(query[0]), 9)
|
||||
self.assertEqual(len(query[1]), 12)
|
||||
self.assertEqual(len(response[0]), 14)
|
||||
self.assertEqual(len(response[1]), 14)
|
||||
self.assertEqual(response_mask[0].sum(), 2 * 3)  # mocked generate always adds 3 tokens
self.assertEqual(response_mask[1].sum(), 2 * 3)  # mocked generate always adds 3 tokens
|
||||
self.assertEqual(reward[0], 0)
|
||||
self.assertEqual(reward[1], 1)
|
||||
self.assertEqual(
|
||||
histories[0].text, "I am a prompt!\n" + "Hello there!" + 2 * "<request><DummyTool>test<call>test<response>"
|
||||
)
|
||||
self.assertEqual(
|
||||
histories[1].text,
|
||||
"I am a prompt!\n" + "Hello there! General Kenobi!" + 2 * "<request><DummyTool>test<call>test<response>",
|
||||
)
|
@ -18,6 +18,7 @@ import re
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import HfApi, HfFolder, delete_repo
|
||||
from parameterized import parameterized
|
||||
@ -208,6 +209,38 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
for stat in EXPECTED_STATS:
|
||||
assert stat in train_stats.keys()
|
||||
|
||||
def test_ppo_step_with_masks(self):
|
||||
# initialize dataset
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
ppo_trainer = PPOTrainer(
|
||||
config=self.ppo_config,
|
||||
model=self.gpt2_model,
|
||||
ref_model=self.gpt2_model_ref,
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
# define a reward for response
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
|
||||
response_mask = [torch.ones_like(r) for r in response_tensor]
|
||||
|
||||
# train model
|
||||
train_stats = ppo_trainer.step(
|
||||
[q for q in query_tensor], [r for r in response_tensor], reward, response_mask
|
||||
)
|
||||
break
|
||||
|
||||
for param in ppo_trainer.model.parameters():
|
||||
assert param.grad is not None
|
||||
|
||||
for stat in EXPECTED_STATS:
|
||||
assert stat in train_stats.keys()
|
||||
|
||||
def test_ppo_step_with_no_ref_sgd(self):
|
||||
# initialize dataset
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
@ -465,7 +498,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# train model - this should raise an error
|
||||
bs = ppo_trainer.config.batch_size
|
||||
|
||||
queries, responses, _ = ppo_trainer._step_safety_checker(
|
||||
queries, responses, _, _ = ppo_trainer._step_safety_checker(
|
||||
bs, [q for q in query_tensor], [r for r in response_tensor], reward
|
||||
)
|
||||
|
||||
@ -1193,3 +1226,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
# train model
|
||||
_ = ppo_trainer.step([q for q in query_tensor], [r for r in response_tensor], reward)
|
||||
break
|
||||
|
||||
def test_batch_size_check(self):
|
||||
with pytest.raises(ValueError):
|
||||
PPOConfig(batch_size=2, mini_batch_size=2, gradient_accumulation_steps=2)
|
||||
|
@ -1,10 +1,11 @@
|
||||
# flake8: noqa
|
||||
|
||||
__version__ = "0.4.8.dev0"
|
||||
__version__ = "0.7.0"
|
||||
|
||||
from .core import set_seed
|
||||
from .environment import TextEnvironment, TextHistory
|
||||
from .extras import BestOfNSampler
|
||||
from .import_utils import is_peft_available
|
||||
from .import_utils import is_diffusers_available, is_peft_available
|
||||
from .models import (
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
@ -12,3 +13,13 @@ from .models import (
|
||||
create_reference_model,
|
||||
)
|
||||
from .trainer import DataCollatorForCompletionOnlyLM, DPOTrainer, PPOConfig, PPOTrainer, RewardTrainer, SFTTrainer
|
||||
|
||||
|
||||
if is_diffusers_available():
|
||||
from .models import (
|
||||
DDPOPipelineOutput,
|
||||
DDPOSchedulerOutput,
|
||||
DDPOStableDiffusionPipeline,
|
||||
DefaultDDPOStableDiffusionPipeline,
|
||||
)
|
||||
from .trainer import DDPOConfig, DDPOTrainer
3 trl/environment/__init__.py Normal file
@ -0,0 +1,3 @@
# flake8: noqa

from .base_environment import TextEnvironment, TextHistory

473 trl/environment/base_environment.py Normal file
@ -0,0 +1,473 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from accelerate.utils import extract_model_from_parallel
|
||||
from transformers import StoppingCriteria, StoppingCriteriaList
|
||||
|
||||
from ..import_utils import is_rich_available
|
||||
|
||||
|
||||
if is_rich_available():
|
||||
from rich import print
|
||||
from rich.text import Text
|
||||
|
||||
|
||||
class StringStoppingCriteria(StoppingCriteria):
|
||||
"""Custom `StoppingCriteria` which checks if all generations in the batch are completed."""
|
||||
|
||||
def __init__(self, stop_strings, tokenizer):
|
||||
self.stop_strings = stop_strings
|
||||
self.tokenizer = tokenizer
|
||||
self.first_call = True
|
||||
|
||||
def __call__(self, input_ids, scores, **kwargs):
|
||||
"""Returns true if all generated sequences contain any of the stop strings."""
|
||||
if self.first_call:
|
||||
self.generated_tokens = [1 for _ in range(input_ids.shape[0])]
|
||||
self.start_length = input_ids.shape[-1] - 1
|
||||
self.first_call = False
|
||||
decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
|
||||
done = []
|
||||
|
||||
for i, decoded_generation in enumerate(decoded_generations):
|
||||
sequence_complete = any([stop_string in decoded_generation for stop_string in self.stop_strings])
|
||||
done.append(sequence_complete)
|
||||
if not sequence_complete:
|
||||
self.generated_tokens[i] += 1
|
||||
|
||||
if all(done):
|
||||
self.first_call = True
|
||||
|
||||
return all(done)
|
||||
|
||||
|
||||
class TextHistory:
|
||||
"""The TextHistory class keeps track of the history of an interaction between the language model and the environment."""
|
||||
|
||||
def __init__(self, text, tokens, system=True):
|
||||
"""
|
||||
Initialize TextHistory.
|
||||
|
||||
args:
|
||||
text (`str`): The text of the first segment.
|
||||
tokens (`torch.LongTensor`): The tokens of the first segment.
|
||||
system (`bool`, *optional*): Whether the first segment is a system or user segment.
|
||||
"""
|
||||
self.system_spans = []
|
||||
self.text_spans = []
|
||||
self.token_spans = []
|
||||
self.token_masks = torch.tensor([], dtype=torch.long).to(tokens.device)
|
||||
self.text = ""
|
||||
self.tokens = torch.tensor([], dtype=torch.long).to(tokens.device)
|
||||
self.completed = False
|
||||
self.truncated = False
|
||||
self.reward = 0.0
|
||||
|
||||
self.prompt_color = "black on grey85"
|
||||
self.system_color = "black on cyan3"
|
||||
self.model_color = "black on deep_sky_blue1"
|
||||
self.reward_color = "black on plum1"
|
||||
|
||||
self.append_segment(text, tokens, system=system)
|
||||
|
||||
def append_segment(self, text, tokens, system=True):
|
||||
"""
|
||||
Append a new segment to the history.
|
||||
|
||||
args:
|
||||
text (`str`): The text of the new segment.
|
||||
tokens (`torch.LongTensor`): The tokens of the new segment.
|
||||
system (`bool`, *optional*): Whether the new segment is a system or user segment.
|
||||
"""
|
||||
|
||||
if len(text) == 0 or len(tokens) == 0:
|
||||
raise ValueError("Can't append empty text or token list to history.")
|
||||
|
||||
original_text_length = len(self.text)
|
||||
|
||||
self.text += text
|
||||
self.text_spans.append((original_text_length, len(self.text)))
|
||||
self.system_spans.append(system)
|
||||
|
||||
original_token_length = len(self.tokens)
|
||||
|
||||
self.tokens = torch.cat((self.tokens, tokens))
|
||||
if system:
|
||||
self.token_masks = torch.cat((self.token_masks, torch.zeros_like(tokens)))
|
||||
else:
|
||||
self.token_masks = torch.cat((self.token_masks, torch.ones_like(tokens)))
|
||||
self.token_spans.append((original_token_length, len(self.tokens)))
|
||||
|
||||
def complete(self, truncated=False):
|
||||
"""
|
||||
Mark the history as completed.
|
||||
"""
|
||||
self.completed = True
|
||||
self.truncated = truncated
|
||||
|
||||
@property
|
||||
def last_text_segment(self):
|
||||
"""
|
||||
Get the last text segment.
|
||||
"""
|
||||
start, end = self.text_spans[-1]
|
||||
return self.text[start:end]
|
||||
|
||||
def split_query_response_tokens(self):
|
||||
"""
|
||||
Split the tokens into query and response tokens.
|
||||
"""
|
||||
split_index = self.token_spans[0][1]
|
||||
query = self.tokens[:split_index]
|
||||
response = self.tokens[split_index:]
|
||||
mask = self.token_masks[split_index:]
|
||||
|
||||
return query, response, mask
|
||||
|
||||
def show_text(self, show_legend=False):
|
||||
"""
|
||||
Print the text history.
|
||||
"""
|
||||
if not is_rich_available():
|
||||
warnings.warn("install rich to display text")
|
||||
return
|
||||
|
||||
text = Text(self.text)
|
||||
text.stylize(self.prompt_color, self.text_spans[0][0], self.text_spans[1][0])
|
||||
for i, (start, end) in enumerate(self.text_spans[1:]):
|
||||
if self.system_spans[i + 1]:
|
||||
text.stylize(self.system_color, start, end)
|
||||
else:
|
||||
text.stylize(self.model_color, start, end)
|
||||
|
||||
text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
|
||||
print(text)
|
||||
|
||||
if show_legend:
|
||||
self.show_colour_legend()
|
||||
|
||||
def show_tokens(self, tokenizer, show_legend=False):
|
||||
"""
|
||||
Print the history tokens.
|
||||
"""
|
||||
if not is_rich_available():
|
||||
warnings.warn("install rich to display tokens")
|
||||
return
|
||||
|
||||
text = Text()
|
||||
prompt_end = self.token_spans[0][1]
|
||||
for i, (token, mask) in enumerate(zip(self.tokens, self.token_masks)):
|
||||
if i < prompt_end:
|
||||
text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.prompt_color)
|
||||
text.append(" ")
|
||||
elif mask == 0:
|
||||
text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.system_color)
|
||||
text.append(" ")
|
||||
else:
|
||||
text.append(tokenizer.convert_ids_to_tokens(token.item()), style=self.model_color)
|
||||
text.append(" ")
|
||||
text.append(f"\n\nReward: {self.reward}", style=self.reward_color)
|
||||
print(text)
|
||||
if show_legend:
|
||||
self.show_colour_legend()
|
||||
|
||||
def show_colour_legend(self):
|
||||
"""
|
||||
Print the colour legend.
|
||||
"""
|
||||
if not is_rich_available():
|
||||
warnings.warn("install rich to display colour legend")
|
||||
return
|
||||
text = Text("\n\n(Colour Legend: ")
|
||||
text.append("Prompt", style=self.prompt_color)
|
||||
text.append("|")
|
||||
text.append("System", style=self.system_color)
|
||||
text.append("|")
|
||||
text.append("Model", style=self.model_color)
|
||||
text.append("|")
|
||||
text.append("Reward", style=self.reward_color)
|
||||
text.append(")")
|
||||
print(text)
|
||||
|
||||
|
||||
class TextEnvironment:
|
||||
"""
|
||||
The TextEnvironment enables interaction of an LLM with an environment using tools.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model=None,
|
||||
tokenizer=None,
|
||||
tools=None,
|
||||
reward_fn=None,
|
||||
prompt=None,
|
||||
max_turns=4,
|
||||
max_tool_reponse=100,
|
||||
max_length=None,
|
||||
generation_kwargs=None,
|
||||
):
|
||||
"""
|
||||
Initialize TextEnvironment.
|
||||
|
||||
Args:
|
||||
model (`PreTrainedModelWrapper`): The model to use for generation.
|
||||
tokenizer (`transformers.PreTrainedTokenizer`): The tokenizer to use for generation.
|
||||
tools (list): A list of tools to use for interaction.
|
||||
reward_fn (function): A function that takes a string and returns a reward.
|
||||
prompt (str): The base prompt to use for generation. Is prepended to the tasks.
|
||||
max_turns (Optional[int]): The maximum number of turns to allow.
|
||||
max_tool_response (Optional[int]): The maximum number of characters to allow in a tool response.
|
||||
max_length (Optional[int]): The maximum number of tokens to allow in an episode.
|
||||
generation_kwargs (Optional[dict]): A dictionary of keyword arguments to pass to the model's generate method.
|
||||
"""
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.prompt = prompt
|
||||
if isinstance(tools, dict):
|
||||
self.tools = tools
|
||||
else:
|
||||
self.tools = dict([(tool.__class__.__name__, tool) for tool in tools])
|
||||
self.reward_fn = reward_fn
|
||||
self.max_length = max_length
|
||||
self.request_token = "<request>"
|
||||
self.call_token = "<call>"
|
||||
self.response_token = "<response>"
|
||||
self.submit_token = "<submit>"
|
||||
self.max_turns = max_turns
|
||||
self.max_tool_response = max_tool_reponse
|
||||
|
||||
if generation_kwargs is None:
|
||||
self.generation_kwargs = dict()
|
||||
else:
|
||||
self.generation_kwargs = generation_kwargs
|
||||
|
||||
self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
|
||||
self.current_device = extract_model_from_parallel(self.model).pretrained_model.device
|
||||
|
||||
def run(self, queries, **rewards_kwargs):
|
||||
"""
|
||||
Run the environment on a list of queries.
|
||||
|
||||
Args:
|
||||
queries (list[str]): A list of queries to run the model in the environment on.
|
||||
"""
|
||||
turns = 0
|
||||
|
||||
queries = [self.prompt + task for task in queries]
|
||||
queries_tokens = [
|
||||
self.tokenizer(query, return_tensors="pt").input_ids[0].to(self.model.pretrained_model.device)
|
||||
for query in queries
|
||||
]
|
||||
|
||||
histories = [TextHistory(q, qt, system=True) for q, qt in zip(queries, queries_tokens)]
|
||||
|
||||
while any([not history.completed for history in histories]) and turns < self.max_turns:
|
||||
histories = self.generate(histories)
|
||||
histories = self.tasks_end_check(histories)
|
||||
# TODO: make this parallel rather than for-loop
|
||||
for i in range(len(histories)):
|
||||
histories[i] = self.step(histories[i])
|
||||
histories = self.tasks_end_check(histories, model_turn=False)
|
||||
turns += 1
|
||||
self.compute_reward(histories, **rewards_kwargs)
|
||||
|
||||
# convert a list of (q, r, m) tuples to lists of all qs, rs, and ms respectively
|
||||
queries, responses, masks = map(list, zip(*[history.split_query_response_tokens() for history in histories]))
|
||||
|
||||
rewards = [history.reward for history in histories]
|
||||
return queries, responses, masks, rewards, histories
|
||||
|
||||
def step(self, history):
|
||||
"""
|
||||
Step the environment forward one turn.
|
||||
|
||||
Args:
|
||||
history (`TextHistory`): The history to step forward.
|
||||
"""
|
||||
truncated, ended = self.task_end_check(history)
|
||||
if ended:
|
||||
history.complete(truncated=truncated)
|
||||
if history.completed:
|
||||
return history
|
||||
|
||||
tool, query = self.parse_tool_call(history.last_text_segment)
|
||||
if tool is None or query is None:
|
||||
response = f"Unknown tool call: {history.last_text_segment}"
|
||||
else:
|
||||
if tool not in self.tools:
|
||||
response = f"Unknown tool {tool}."
|
||||
try:
|
||||
response = self.tools[tool](query)
|
||||
except Exception as error:
|
||||
response = f"Tool error: {str(error)}"
|
||||
|
||||
if len(response) > self.max_tool_response:
|
||||
response = response[: (self.max_tool_response - 3)] + "..."
|
||||
|
||||
history.append_segment(
|
||||
response + self.response_token,
|
||||
self.tokenizer(response + self.response_token, return_tensors="pt")
|
||||
.input_ids[0]
|
||||
.to(self.model.pretrained_model.device),
|
||||
system=True,
|
||||
)
|
||||
|
||||
return history
|
||||
|
||||
def parse_tool_call(self, text):
|
||||
"""
|
||||
Parse request string. Expected format: <request><tool_name>query<call>
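Example (with a hypothetical `Calculator` tool): "<request><Calculator>2 + 2<call>" -> ("Calculator", "2 + 2")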
|
||||
"""
|
||||
result = re.search(f"(?<={self.request_token}).*?(?={self.call_token})", text, re.DOTALL)
|
||||
|
||||
# if we can't find a <request>/<call> span we return none
|
||||
if result is None:
|
||||
return None, None
|
||||
else:
|
||||
extracted_text = result.group()
|
||||
|
||||
result = re.search(r"<(.*?)>", extracted_text)
|
||||
|
||||
# if we can't find a tool name we return none
|
||||
if result is None:
|
||||
return None, None
|
||||
else:
|
||||
tool = result.group(1)
|
||||
|
||||
# split off the tool name
|
||||
query = ">".join(extracted_text.split(">")[1:])
|
||||
|
||||
return tool, query
|
||||
|
||||
def compute_reward(self, histories, **reward_kwargs):
|
||||
"""
|
||||
Compute the reward for a list of histories.
|
||||
"""
|
||||
rewards = self.reward_fn([history.last_text_segment for history in histories], **reward_kwargs)
|
||||
for history, reward in zip(histories, rewards):
|
||||
history.reward = reward
|
||||
return histories
|
||||
|
||||
def generate(self, histories):
|
||||
"""
|
||||
Generate responses for a list of histories.
|
||||
"""
|
||||
active_histories = [i for i, history in enumerate(histories) if not history.completed]
|
||||
|
||||
query_tensors = [histories[i].tokens for i in active_histories]
|
||||
response_tensors = self._generate_batched(query_tensors)
|
||||
response_texts = self.tokenizer.batch_decode(response_tensors)
|
||||
|
||||
for i, response_text, response_tensor in zip(active_histories, response_texts, response_tensors):
|
||||
histories[i].append_segment(response_text, response_tensor, system=False)
|
||||
|
||||
return histories
|
||||
|
||||
def tasks_end_check(self, histories, model_turn=True):
|
||||
"""
|
||||
Check if the current generation sequences have finished.
|
||||
"""
|
||||
for history in histories:
|
||||
if not history.completed:
|
||||
truncated, ended = self.task_end_check(history, model_turn=model_turn)
|
||||
if ended:
|
||||
history.complete(truncated=truncated)
|
||||
return histories
|
||||
|
||||
def task_end_check(self, history, model_turn=True):
|
||||
"""
|
||||
Check if the current generation sequence has finished.
|
||||
"""
|
||||
truncated = False
|
||||
ended = False
|
||||
if history.completed:
|
||||
return truncated, ended
|
||||
if self.max_length is not None and len(self.tokenizer(history.text).input_ids[0]) > self.max_length:
|
||||
truncated = True
|
||||
ended = True
|
||||
elif self.tokenizer.eos_token in history.text:
|
||||
ended = True
|
||||
elif model_turn and not (
|
||||
(self.request_token in history.last_text_segment and self.call_token in history.last_text_segment)
|
||||
or self.submit_token in history.last_text_segment
|
||||
):
|
||||
ended = True
|
||||
elif self.submit_token in history.last_text_segment:
|
||||
ended = True
|
||||
return truncated, ended
|
||||
|
||||
def _generate_batched(
|
||||
self,
|
||||
query_tensors,
|
||||
batch_size: int = 16,
|
||||
pad_to_multiple_of: int = None,
|
||||
):
|
||||
"""
|
||||
Generate responses for a list of query tensors.
|
||||
|
||||
args:
|
||||
query_tensors (list[torch.Tensor]): A list of query tensors to generate responses for.
|
||||
batch_size (int): The batch size to use for generation.
|
||||
pad_to_multiple_of (int): The padding length to use for generation.
|
||||
"""
|
||||
outputs = []
|
||||
padding_side_default = self.tokenizer.padding_side
|
||||
if not self.is_encoder_decoder:
|
||||
self.tokenizer.padding_side = "left"
|
||||
|
||||
# in case we have fewer examples than bs
|
||||
batch_size = min(len(query_tensors), batch_size)
|
||||
|
||||
for i in range(0, len(query_tensors), batch_size):
|
||||
# prevent overflow if the number of query tensors is not an even multiple of bs
|
||||
end_index = min(len(query_tensors), i + batch_size)
|
||||
|
||||
batch = query_tensors[i:end_index]
|
||||
batch_mask = [torch.ones_like(element) for element in batch]
|
||||
inputs = {"input_ids": batch, "attention_mask": batch_mask}
|
||||
|
||||
padded_inputs = self.tokenizer.pad(
|
||||
inputs,
|
||||
padding=True,
|
||||
max_length=None,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_tensors="pt",
|
||||
).to(self.current_device)
|
||||
|
||||
stopping_criteria = StringStoppingCriteria([self.call_token, self.submit_token], self.tokenizer)
|
||||
|
||||
self.generation_kwargs["stopping_criteria"] = StoppingCriteriaList([stopping_criteria])
|
||||
|
||||
generations = extract_model_from_parallel(self.model).generate(**padded_inputs, **self.generation_kwargs)
|
||||
|
||||
for generation, mask, generated_tokens in zip(
|
||||
generations, padded_inputs["attention_mask"], stopping_criteria.generated_tokens
|
||||
):
|
||||
if not self.is_encoder_decoder:
|
||||
output = generation[(1 - mask).sum() :] # remove padding
|
||||
else:
|
||||
output = generation
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
output = output[(mask).sum() :] # remove prompt
|
||||
|
||||
# remove chunk generated after stopping criteria in batch mode
|
||||
outputs.append(output[:generated_tokens])
|
||||
self.tokenizer.padding_side = padding_side_default
|
||||
return outputs
|
@ -35,3 +35,19 @@ def is_torch_greater_2_0():
|
||||
|
||||
torch_version = pkg_resources.get_distribution("torch").version
|
||||
return torch_version >= "2.0"
|
||||
|
||||
|
||||
def is_diffusers_available():
|
||||
return importlib.util.find_spec("diffusers") is not None
|
||||
|
||||
|
||||
def is_bitsandbytes_available():
|
||||
return importlib.util.find_spec("bitsandbytes") is not None
|
||||
|
||||
|
||||
def is_torchvision_available():
|
||||
return importlib.util.find_spec("torchvision") is not None
|
||||
|
||||
|
||||
def is_rich_available():
|
||||
return importlib.util.find_spec("rich") is not None
|
||||
|
@ -21,3 +21,14 @@ SUPPORTED_ARCHITECTURES = (
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
)
|
||||
|
||||
from ..import_utils import is_diffusers_available
|
||||
|
||||
|
||||
if is_diffusers_available():
|
||||
from .modeling_sd_base import (
|
||||
DDPOPipelineOutput,
|
||||
DDPOSchedulerOutput,
|
||||
DDPOStableDiffusionPipeline,
|
||||
DefaultDDPOStableDiffusionPipeline,
|
||||
)
|
||||
|
@ -20,6 +20,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from accelerate import Accelerator
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import EntryNotFoundError, HFValidationError, LocalEntryNotFoundError
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..import_utils import is_peft_available
|
||||
@ -115,12 +116,14 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
reward_adapter = kwargs.pop("reward_adapter", None)
|
||||
is_trainable = kwargs.pop("is_trainable", False)
|
||||
trl_model_args, pretrained_kwargs, peft_quantization_kwargs = cls._split_kwargs(kwargs)
|
||||
token = pretrained_kwargs.get("token", None)
|
||||
else:
|
||||
peft_config = None
|
||||
is_trainable = False
|
||||
trl_model_args = {}
|
||||
pretrained_kwargs = {}
|
||||
peft_quantization_kwargs = {}
|
||||
token = None
|
||||
|
||||
if reward_adapter is not None and not isinstance(reward_adapter, str):
|
||||
raise ValueError(
|
||||
@ -156,8 +159,12 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
if is_peft_available():
|
||||
try:
|
||||
# If there is a trained peft adapter in the hub, load its config.
|
||||
remote_adapter_config = hf_hub_download(pretrained_model_name_or_path, "adapter_config.json")
|
||||
except: # noqa
|
||||
remote_adapter_config = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
"adapter_config.json",
|
||||
token=token,
|
||||
)
|
||||
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError):
|
||||
remote_adapter_config = None
|
||||
else:
|
||||
remote_adapter_config = None
|
||||
@ -175,7 +182,8 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
if local_adapter_present:
|
||||
trained_adapter_config = PeftConfig.from_pretrained(pretrained_model_name_or_path)
|
||||
else:
|
||||
trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_config)
|
||||
remote_adapter_dir = os.path.dirname(remote_adapter_config)
|
||||
trained_adapter_config = PeftConfig.from_pretrained(remote_adapter_dir)
|
||||
|
||||
# Load the pretrained base model
|
||||
pretrained_model = cls.transformers_parent_class.from_pretrained(
|
||||
@ -241,17 +249,24 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
|
||||
if not os.path.exists(filename):
|
||||
try:
|
||||
filename = hf_hub_download(pretrained_model_name_or_path, "pytorch_model.bin")
|
||||
filename = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
"pytorch_model.bin",
|
||||
token=token,
|
||||
)
|
||||
# sharded
|
||||
except: # noqa
|
||||
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError):
|
||||
if os.path.exists(sharded_index_filename):
|
||||
index_file_name = sharded_index_filename
|
||||
else:
|
||||
try:
|
||||
index_file_name = hf_hub_download(
|
||||
pretrained_model_name_or_path, "pytorch_model.bin.index.json"
|
||||
pretrained_model_name_or_path,
|
||||
"pytorch_model.bin.index.json",
|
||||
token=token,
|
||||
)
|
||||
except ValueError: # not continue training, do not have v_head weight
|
||||
except (EntryNotFoundError, LocalEntryNotFoundError, HFValidationError):
|
||||
# not continue training, do not have v_head weight
|
||||
is_resuming_training = False
|
||||
logging.warning(
|
||||
f"A {type(pretrained_model)} model is loaded from '{pretrained_model_name_or_path}', "
|
||||
@ -267,12 +282,17 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
if any([module in k for module in cls.supported_modules]):
|
||||
files_to_download.add(v)
|
||||
is_shared = True
|
||||
|
||||
if is_resuming_training:
|
||||
if is_shared:
|
||||
# download each file and add it to the state_dict
|
||||
state_dict = {}
|
||||
for shard_file in files_to_download:
|
||||
filename = hf_hub_download(pretrained_model_name_or_path, shard_file)
|
||||
filename = hf_hub_download(
|
||||
pretrained_model_name_or_path,
|
||||
shard_file,
|
||||
token=token,
|
||||
)
|
||||
state_dict.update(torch.load(filename, map_location="cpu"))
|
||||
else:
|
||||
state_dict = torch.load(filename, map_location="cpu")
|
||||
@ -290,7 +310,7 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
if not is_peft_model and reward_adapter is not None:
|
||||
raise ValueError("reward_adapter can only be used with a PeftModel. ")
|
||||
elif is_peft_model and reward_adapter is not None:
|
||||
model.add_and_load_reward_modeling_adapter(reward_adapter)
|
||||
model.add_and_load_reward_modeling_adapter(reward_adapter, token=token)
|
||||
model.supports_rm_adapter = True
|
||||
else:
|
||||
model.supports_rm_adapter = False
|
||||
@ -400,7 +420,7 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def add_and_load_reward_modeling_adapter(self, adapter_model_id, adapter_name="reward_model_adapter"):
|
||||
def add_and_load_reward_modeling_adapter(self, adapter_model_id, adapter_name="reward_model_adapter", token=None):
|
||||
r"""
|
||||
Add and load a reward modeling adapter. This method can only be used if the
|
||||
model is a `PeftModel` and if you have initialized the model with the `reward_modeling_adapter_id`
|
||||
@ -410,7 +430,11 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
filename = os.path.join(adapter_model_id, "adapter_model.bin")
|
||||
if not os.path.exists(filename):
|
||||
try:
|
||||
local_filename = hf_hub_download(adapter_model_id, "adapter_model.bin")
|
||||
local_filename = hf_hub_download(
|
||||
adapter_model_id,
|
||||
"adapter_model.bin",
|
||||
token=token,
|
||||
)
|
||||
except: # noqa
|
||||
raise ValueError(
|
||||
"Could not find adapter model in the Hub, make sure you have the correct adapter model id."
|
||||
@ -441,7 +465,10 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
num_labels, hidden_dim = score_dict["weight"].shape
|
||||
has_bias = any(["bias" in name for name in adapter_state_dict.keys()])
|
||||
|
||||
self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(self._get_current_device())
|
||||
self.score = nn.Linear(hidden_dim, num_labels, bias=has_bias).to(
|
||||
device=self._get_current_device(),
|
||||
dtype=self.pretrained_model.dtype,
|
||||
)
|
||||
self.score.load_state_dict(score_dict)
|
||||
|
||||
# load the adapter to the model
|
644 trl/models/modeling_sd_base.py Normal file
@ -0,0 +1,644 @@
# Copyright 2023 DDPO-pytorch authors (Kevin Black), The HuggingFace Team, metric-space. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from diffusers import DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
|
||||
from diffusers.utils import randn_tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDPOPipelineOutput(object):
|
||||
"""
|
||||
Output class for the diffusers pipeline to be finetuned with the DDPO trainer
|
||||
|
||||
Args:
|
||||
images (`torch.Tensor`):
|
||||
The generated images.
|
||||
latents (`List[torch.Tensor]`):
|
||||
The latents used to generate the images.
|
||||
log_probs (`List[torch.Tensor]`):
|
||||
The log probabilities of the latents.
|
||||
|
||||
"""
|
||||
|
||||
images: torch.Tensor
|
||||
latents: torch.Tensor
|
||||
log_probs: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDPOSchedulerOutput(object):
|
||||
"""
|
||||
Output class for the diffusers scheduler to be finetuned with the DDPO trainer
|
||||
|
||||
Args:
|
||||
latents (`torch.Tensor`):
|
||||
Predicted sample at the previous timestep. Shape: `(batch_size, num_channels, height, width)`
|
||||
log_probs (`torch.Tensor`):
|
||||
Log probability of the above mentioned sample. Shape: `(batch_size)`
|
||||
"""
|
||||
|
||||
latents: torch.Tensor
|
||||
log_probs: torch.Tensor
|
||||
|
||||
|
||||
class DDPOStableDiffusionPipeline(object):
|
||||
"""
|
||||
Main class for the diffusers pipeline to be finetuned with the DDPO trainer
|
||||
"""
|
||||
|
||||
def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def unet(self):
|
||||
"""
|
||||
Returns the 2d U-Net model used for diffusion.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def vae(self):
|
||||
"""
|
||||
Returns the Variational Autoencoder model used for mapping images to and from the latent space
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
"""
|
||||
Returns the tokenizer used for tokenizing text inputs
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
"""
|
||||
Returns the scheduler associated with the pipeline used for the diffusion process
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def text_encoder(self):
|
||||
"""
|
||||
Returns the text encoder used for encoding text inputs
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def autocast(self):
|
||||
"""
|
||||
Returns the autocast context manager
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_progress_bar_config(self, *args, **kwargs):
|
||||
"""
|
||||
Sets the progress bar config for the pipeline
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save_pretrained(self, *args, **kwargs):
|
||||
"""
|
||||
Saves all of the model weights
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_trainable_layers(self, *args, **kwargs):
|
||||
"""
|
||||
Returns the trainable parameters of the pipeline
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def save_checkpoint(self, *args, **kwargs):
|
||||
"""
|
||||
Light wrapper around accelerate's register_save_state_pre_hook which is run before saving state
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_checkpoint(self, *args, **kwargs):
|
||||
"""
|
||||
Light wrapper around accelerate's register_load_state_pre_hook which is run before loading state
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _left_broadcast(input_tensor, shape):
|
||||
"""
|
||||
As opposed to the default direction of broadcasting (right to left), this function broadcasts
|
||||
from left to right
|
||||
Args:
|
||||
input_tensor (`torch.FloatTensor`): is the tensor to broadcast
|
||||
shape (`Tuple[int]`): is the shape to broadcast to
|
||||
"""
|
||||
input_ndim = input_tensor.ndim
|
||||
if input_ndim > len(shape):
|
||||
raise ValueError(
|
||||
"The number of dimensions of the tensor to broadcast cannot be greater than the length of the shape to broadcast to"
|
||||
)
|
||||
return input_tensor.reshape(input_tensor.shape + (1,) * (len(shape) - input_ndim)).broadcast_to(shape)
|
||||
|
||||
|
||||
def _get_variance(self, timestep, prev_timestep):
|
||||
alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
|
||||
alpha_prod_t_prev = torch.where(
|
||||
prev_timestep.cpu() >= 0,
|
||||
self.alphas_cumprod.gather(0, prev_timestep.cpu()),
|
||||
self.final_alpha_cumprod,
|
||||
).to(timestep.device)
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
|
||||
|
||||
return variance
|
||||
|
||||
|
||||
def scheduler_step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: int,
|
||||
sample: torch.FloatTensor,
|
||||
eta: float = 0.0,
|
||||
use_clipped_model_output: bool = False,
|
||||
generator=None,
|
||||
prev_sample: Optional[torch.FloatTensor] = None,
|
||||
) -> DDPOSchedulerOutput:
|
||||
"""
|
||||
|
||||
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`torch.FloatTensor`):
|
||||
current instance of sample being created by diffusion process.
|
||||
eta (`float`): weight of noise for added noise in diffusion step.
|
||||
use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
|
||||
predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
|
||||
`self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
|
||||
coincide with the one provided as input and `use_clipped_model_output` will have no effect.
|
||||
generator: random number generator.
|
||||
variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
|
||||
can directly provide the noise for the variance itself. This is useful for methods such as
|
||||
CycleDiffusion. (https://arxiv.org/abs/2210.05559)
|
||||
|
||||
Returns:
|
||||
`DDPOSchedulerOutput`: the predicted sample at the previous timestep and the log probability of the sample
|
||||
"""
|
||||
|
||||
if self.num_inference_steps is None:
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
|
||||
# Ideally, read the DDIM paper for an in-detail understanding
|
||||
|
||||
# Notation (<variable name> -> <name in paper>
|
||||
# - pred_noise_t -> e_theta(x_t, t)
|
||||
# - pred_original_sample -> f_theta(x_t, t) or x_0
|
||||
# - std_dev_t -> sigma_t
|
||||
# - eta -> η
|
||||
# - pred_sample_direction -> "direction pointing to x_t"
|
||||
# - pred_prev_sample -> "x_t-1"
|
||||
|
||||
# 1. get previous step value (=t-1)
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
|
||||
# to prevent OOB on gather
|
||||
prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
|
||||
|
||||
# 2. compute alphas, betas
|
||||
alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
|
||||
alpha_prod_t_prev = torch.where(
|
||||
prev_timestep.cpu() >= 0,
|
||||
self.alphas_cumprod.gather(0, prev_timestep.cpu()),
|
||||
self.final_alpha_cumprod,
|
||||
)
|
||||
alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
|
||||
alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)
|
||||
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
|
||||
# 3. compute predicted original sample from predicted noise also called
|
||||
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
if self.config.prediction_type == "epsilon":
|
||||
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
|
||||
pred_epsilon = model_output
|
||||
elif self.config.prediction_type == "sample":
|
||||
pred_original_sample = model_output
|
||||
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
||||
elif self.config.prediction_type == "v_prediction":
|
||||
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
|
||||
pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
|
||||
else:
|
||||
raise ValueError(
|
||||
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
|
||||
" `v_prediction`"
|
||||
)
|
||||
|
||||
# 4. Clip or threshold "predicted x_0"
|
||||
if self.config.thresholding:
|
||||
pred_original_sample = self._threshold_sample(pred_original_sample)
|
||||
elif self.config.clip_sample:
|
||||
pred_original_sample = pred_original_sample.clamp(
|
||||
-self.config.clip_sample_range, self.config.clip_sample_range
|
||||
)
|
||||
|
||||
# 5. compute variance: "sigma_t(η)" -> see formula (16)
|
||||
# σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
|
||||
variance = _get_variance(self, timestep, prev_timestep)
|
||||
std_dev_t = eta * variance ** (0.5)
|
||||
std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
|
||||
|
||||
if use_clipped_model_output:
|
||||
# the pred_epsilon is always re-derived from the clipped x_0 in Glide
|
||||
pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
|
||||
|
||||
# 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
|
||||
|
||||
# 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
|
||||
prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
|
||||
|
||||
if prev_sample is not None and generator is not None:
|
||||
raise ValueError(
|
||||
"Cannot pass both generator and prev_sample. Please make sure that either `generator` or"
|
||||
" `prev_sample` stays `None`."
|
||||
)
|
||||
|
||||
if prev_sample is None:
|
||||
variance_noise = randn_tensor(
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=model_output.dtype,
|
||||
)
|
||||
prev_sample = prev_sample_mean + std_dev_t * variance_noise
|
||||
|
||||
# log prob of prev_sample given prev_sample_mean and std_dev_t
|
||||
log_prob = (
|
||||
-((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
|
||||
- torch.log(std_dev_t)
|
||||
- torch.log(torch.sqrt(2 * torch.as_tensor(np.pi)))
|
||||
)
|
||||
# mean along all but batch dimension
|
||||
log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
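# Note: this is the diagonal-Gaussian log-density of `prev_sample` under
# N(prev_sample_mean, std_dev_t**2):
#   log p(x) = -(x - mu)**2 / (2 * sigma**2) - log(sigma) - 0.5 * log(2 * pi)
# Averaging over all non-batch dimensions leaves one scalar log-prob per sample.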
|
||||
|
||||
return DDPOSchedulerOutput(prev_sample.type(sample.dtype), log_prob)
|
||||
|
||||
|
||||
# 1. The output type for call is different as the logprobs are now returned
|
||||
# 2. An extra method called `scheduler_step` is added which is used to constrain the scheduler output
|
||||
@torch.no_grad()
|
||||
def pipeline_step(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.

Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
Guidance rescale factor should fix overexposure when using zero terminal SNR.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
`DDPOPipelineOutput`: The generated image, the predicted latents used to generate the image and the associated log probabilities
|
||||
"""
|
||||
# 0. Default height and width to unet
|
||||
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
||||
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 5. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 6. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
all_latents = [latents]
|
||||
all_log_probs = []
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
scheduler_output = scheduler_step(self.scheduler, noise_pred, t, latents, eta)
|
||||
latents = scheduler_output.latents
|
||||
log_prob = scheduler_output.log_probs
|
||||
|
||||
all_latents.append(latents)
|
||||
all_log_probs.append(log_prob)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
return DDPOPipelineOutput(image, all_latents, all_log_probs)
|
||||
|
||||
|
||||
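For orientation, a minimal sketch of consuming the `DDPOPipelineOutput` returned above. The checkpoint name is a placeholder, and the import assumes the pipeline class is re-exported at the package root as in the DDPO example script.

```python
# Hedged usage sketch: run the pipeline once and inspect the trajectory it records.
import torch

from trl import DefaultDDPOStableDiffusionPipeline  # assumed export path

pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")  # placeholder checkpoint
with torch.no_grad():
    output = pipeline("a photo of a corgi", num_inference_steps=50, output_type="pt")

images = output.images        # decoded images as tensors
latents = output.latents      # list of num_inference_steps + 1 latents (the full denoising trajectory)
log_probs = output.log_probs  # list of num_inference_steps per-step log probabilities
```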
class DefaultDDPOStableDiffusionPipeline(DDPOStableDiffusionPipeline):
|
||||
def __init__(self, pretrained_model_name: str, *, pretrained_model_revision: str = "main", use_lora: bool = True):
|
||||
self.sd_pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
pretrained_model_name, revision=pretrained_model_revision
|
||||
)
|
||||
|
||||
self.use_lora = use_lora
|
||||
self.pretrained_model = pretrained_model_name
|
||||
self.pretrained_revision = pretrained_model_revision
|
||||
|
||||
try:
|
||||
self.sd_pipeline.unet.load_attn_procs(pretrained_model_name, revision=pretrained_model_revision)
|
||||
self.use_lora = True
|
||||
except OSError:
|
||||
if use_lora:
|
||||
warnings.warn(
|
||||
"If you are aware that the pretrained model has no lora weights to it, ignore this message. "
|
||||
"Otherwise please check the if `pytorch_lora_weights.safetensors` exists in the model folder."
|
||||
)
|
||||
|
||||
self.sd_pipeline.scheduler = DDIMScheduler.from_config(self.sd_pipeline.scheduler.config)
|
||||
self.sd_pipeline.safety_checker = None
|
||||
|
||||
# memory optimization
|
||||
self.sd_pipeline.vae.requires_grad_(False)
|
||||
self.sd_pipeline.text_encoder.requires_grad_(False)
|
||||
self.sd_pipeline.unet.requires_grad_(not self.use_lora)
|
||||
|
||||
def __call__(self, *args, **kwargs) -> DDPOPipelineOutput:
|
||||
return pipeline_step(self.sd_pipeline, *args, **kwargs)
|
||||
|
||||
def scheduler_step(self, *args, **kwargs) -> DDPOSchedulerOutput:
|
||||
return scheduler_step(self.sd_pipeline.scheduler, *args, **kwargs)
|
||||
|
||||
@property
|
||||
def unet(self):
|
||||
return self.sd_pipeline.unet
|
||||
|
||||
@property
|
||||
def vae(self):
|
||||
return self.sd_pipeline.vae
|
||||
|
||||
@property
|
||||
def tokenizer(self):
|
||||
return self.sd_pipeline.tokenizer
|
||||
|
||||
@property
|
||||
def scheduler(self):
|
||||
return self.sd_pipeline.scheduler
|
||||
|
||||
@property
|
||||
def text_encoder(self):
|
||||
return self.sd_pipeline.text_encoder
|
||||
|
||||
@property
|
||||
def autocast(self):
|
||||
return contextlib.nullcontext if self.use_lora else None
|
||||
|
||||
def save_pretrained(self, output_dir):
|
||||
if self.use_lora:
|
||||
self.sd_pipeline.unet.save_attn_procs(output_dir)
|
||||
self.sd_pipeline.save_pretrained(output_dir)
|
||||
|
||||
def set_progress_bar_config(self, *args, **kwargs):
|
||||
self.sd_pipeline.set_progress_bar_config(*args, **kwargs)
|
||||
|
||||
def get_trainable_layers(self):
|
||||
if self.use_lora:
|
||||
# Set correct lora layers
|
||||
lora_attn_procs = {}
|
||||
for name in self.sd_pipeline.unet.attn_processors.keys():
|
||||
cross_attention_dim = (
|
||||
None if name.endswith("attn1.processor") else self.sd_pipeline.unet.config.cross_attention_dim
|
||||
)
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = self.sd_pipeline.unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(self.sd_pipeline.unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = self.sd_pipeline.unet.config.block_out_channels[block_id]
|
||||
|
||||
lora_attn_procs[name] = LoRAAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
|
||||
)
|
||||
self.sd_pipeline.unet.set_attn_processor(lora_attn_procs)
|
||||
return AttnProcsLayers(self.sd_pipeline.unet.attn_processors)
|
||||
else:
|
||||
return self.sd_pipeline.unet
|
||||
|
||||
def save_checkpoint(self, models, weights, output_dir):
|
||||
if len(models) != 1:
|
||||
raise ValueError("Given how the trainable params were set, this should be of length 1")
|
||||
if self.use_lora and isinstance(models[0], AttnProcsLayers):
|
||||
self.sd_pipeline.unet.save_attn_procs(output_dir)
|
||||
elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
|
||||
models[0].save_pretrained(os.path.join(output_dir, "unet"))
|
||||
else:
|
||||
raise ValueError(f"Unknown model type {type(models[0])}")
|
||||
|
||||
def load_checkpoint(self, models, input_dir):
|
||||
if len(models) != 1:
|
||||
raise ValueError("Given how the trainable params were set, this should be of length 1")
|
||||
if self.use_lora and isinstance(models[0], AttnProcsLayers):
|
||||
tmp_unet = UNet2DConditionModel.from_pretrained(
|
||||
self.pretrained_model,
|
||||
revision=self.pretrained_revision,
|
||||
subfolder="unet",
|
||||
)
|
||||
tmp_unet.load_attn_procs(input_dir)
|
||||
models[0].load_state_dict(AttnProcsLayers(tmp_unet.attn_processors).state_dict())
|
||||
del tmp_unet
|
||||
elif not self.use_lora and isinstance(models[0], UNet2DConditionModel):
|
||||
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
|
||||
models[0].register_to_config(**load_model.config)
|
||||
models[0].load_state_dict(load_model.state_dict())
|
||||
del load_model
|
||||
else:
|
||||
raise ValueError(f"Unknown model type {type(models[0])}")
|
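As a rough illustration of the `use_lora` switch in this class (checkpoint name is a placeholder; the printed types follow from `get_trainable_layers` defined further up):

```python
# Sketch: with use_lora=True only LoRA attention processors are trainable,
# with use_lora=False the whole UNet is returned for training.
from trl import DefaultDDPOStableDiffusionPipeline  # assumed export path

lora_pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5", use_lora=True)
print(type(lora_pipeline.get_trainable_layers()))  # AttnProcsLayers wrapping the LoRA processors

full_pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5", use_lora=False)
print(type(full_pipeline.get_trainable_layers()))  # the UNet2DConditionModel itself
```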
@ -16,11 +16,25 @@
|
||||
|
||||
# There is a circular import in the PPOTrainer if we let isort sort these
|
||||
# isort: off
|
||||
from .utils import AdaptiveKLController, FixedKLController, ConstantLengthDataset, DataCollatorForCompletionOnlyLM
|
||||
from .utils import (
|
||||
AdaptiveKLController,
|
||||
FixedKLController,
|
||||
ConstantLengthDataset,
|
||||
DataCollatorForCompletionOnlyLM,
|
||||
RunningMoments,
|
||||
disable_dropout_in_model,
|
||||
)
|
||||
|
||||
# isort: on
|
||||
|
||||
from ..import_utils import is_diffusers_available
|
||||
from .base import BaseTrainer
|
||||
from .ddpo_config import DDPOConfig
|
||||
|
||||
|
||||
if is_diffusers_available():
|
||||
from .ddpo_trainer import DDPOTrainer
|
||||
|
||||
from .dpo_trainer import DPOTrainer
|
||||
from .ppo_config import PPOConfig
|
||||
from .ppo_trainer import PPOTrainer
|
||||
|
140
trl/trainer/ddpo_config.py
Normal file
@ -0,0 +1,140 @@
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from ..core import flatten_dict
|
||||
from ..import_utils import is_bitsandbytes_available, is_torchvision_available
|
||||
|
||||
|
||||
@dataclass
|
||||
class DDPOConfig(object):
|
||||
"""
|
||||
Configuration class for DDPOTrainer
|
||||
"""
|
||||
|
||||
run_name: Optional[str] = field(
|
||||
default="",
|
||||
metadata={"help": "Run name for wandb logging and checkpoint saving."},
|
||||
)
|
||||
seed: Optional[int] = field(default=42, metadata={"help": "Random seed for reproducibility."})
|
||||
logdir: Optional[str] = field(
|
||||
default="logs",
|
||||
metadata={"help": "Top-level logging directory for checkpoint saving."},
|
||||
)
|
||||
log_with: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Log with either 'wandb' or 'tensorboard', check https://huggingface.co/docs/accelerate/usage_guides/tracking for more details"
|
||||
},
|
||||
)
|
||||
tracker_kwargs: Optional[dict] = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Keyword arguments for the tracker (e.g. wandb_project)"},
|
||||
)
|
||||
accelerator_kwargs: Optional[dict] = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Keyword arguments for the accelerator"},
|
||||
)
|
||||
project_kwargs: Optional[dict] = field(
|
||||
default_factory=dict,
|
||||
metadata={"help": "Keyword arguments for the accelerator project config (e.g. `logging_dir`)"},
|
||||
)
|
||||
tracker_project_name: Optional[str] = field(
|
||||
default="trl", metadata={"help": "Name of project to use for tracking"}
|
||||
)
|
||||
num_epochs: Optional[int] = field(default=100, metadata={"help": "Number of epochs to train."})
|
||||
save_freq: Optional[int] = field(
|
||||
default=1,
|
||||
metadata={"help": "Number of epochs between saving model checkpoints."},
|
||||
)
|
||||
num_checkpoint_limit: Optional[int] = field(
|
||||
default=5,
|
||||
metadata={"help": "Number of checkpoints to keep before overwriting old ones."},
|
||||
)
|
||||
mixed_precision: Optional[str] = field(default="fp16", metadata={"help": "Mixed precision training."})
|
||||
allow_tf32: Optional[bool] = field(default=True, metadata={"help": "Allow tf32 on Ampere GPUs."})
|
||||
resume_from: Optional[str] = field(default="", metadata={"help": "Resume training from a checkpoint."})
|
||||
sample_num_steps: Optional[int] = field(default=50, metadata={"help": "Number of sampler inference steps."})
|
||||
sample_eta: Optional[float] = field(default=1.0, metadata={"help": "Eta parameter for the DDIM sampler."})
|
||||
sample_guidance_scale: Optional[float] = field(default=5.0, metadata={"help": "Classifier-free guidance weight."})
|
||||
sample_batch_size: Optional[int] = field(
|
||||
default=1, metadata={"help": "Batch size (per GPU!) to use for sampling."}
|
||||
)
|
||||
sample_num_batches_per_epoch: Optional[int] = field(
|
||||
default=2, metadata={"help": "Number of batches to sample per epoch."}
|
||||
)
|
||||
train_batch_size: Optional[int] = field(default=1, metadata={"help": "Batch size (per GPU!) to use for training."})
|
||||
train_use_8bit_adam: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use the 8bit Adam optimizer from bitsandbytes."},
|
||||
)
|
||||
train_learning_rate: Optional[float] = field(default=3e-4, metadata={"help": "Learning rate."})
|
||||
train_adam_beta1: Optional[float] = field(default=0.9, metadata={"help": "Adam beta1."})
|
||||
train_adam_beta2: Optional[float] = field(default=0.999, metadata={"help": "Adam beta2."})
|
||||
train_adam_weight_decay: Optional[float] = field(default=1e-4, metadata={"help": "Adam weight decay."})
|
||||
train_adam_epsilon: Optional[float] = field(default=1e-8, metadata={"help": "Adam epsilon."})
|
||||
train_gradient_accumulation_steps: Optional[int] = field(
|
||||
default=1, metadata={"help": "Number of gradient accumulation steps."}
|
||||
)
|
||||
train_max_grad_norm: Optional[float] = field(
|
||||
default=1.0, metadata={"help": "Maximum gradient norm for gradient clipping."}
|
||||
)
|
||||
train_num_inner_epochs: Optional[int] = field(
|
||||
default=1, metadata={"help": "Number of inner epochs per outer epoch."}
|
||||
)
|
||||
train_cfg: Optional[bool] = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use classifier-free guidance during training."},
|
||||
)
|
||||
train_adv_clip_max: Optional[float] = field(default=5, metadata={"help": "Clip advantages to the range [-train_adv_clip_max, train_adv_clip_max]."})
|
||||
train_clip_range: Optional[float] = field(default=1e-4, metadata={"help": "The PPO clip range."})
|
||||
train_timestep_fraction: Optional[float] = field(
|
||||
default=1.0, metadata={"help": "The fraction of timesteps to train on."}
|
||||
)
|
||||
|
||||
per_prompt_stat_tracking: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to track statistics for each prompt separately."},
|
||||
)
|
||||
|
||||
per_prompt_stat_tracking_buffer_size: Optional[int] = field(
|
||||
default=16,
|
||||
metadata={"help": "Number of reward values to store in the buffer for each prompt."},
|
||||
)
|
||||
per_prompt_stat_tracking_min_count: Optional[int] = field(
|
||||
default=16,
|
||||
metadata={"help": "The minimum number of reward values to store in the buffer."},
|
||||
)
|
||||
async_reward_computation: Optional[bool] = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to compute rewards asynchronously."},
|
||||
)
|
||||
max_workers: Optional[int] = field(
|
||||
default=2,
|
||||
metadata={"help": "The maximum number of workers to use for async reward computation."},
|
||||
)
|
||||
negative_prompts: Optional[str] = field(
|
||||
default="",
|
||||
metadata={"help": "Comma-separated list of prompts to use as negative examples."},
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
output_dict = {}
|
||||
for key, value in self.__dict__.items():
|
||||
output_dict[key] = value
|
||||
return flatten_dict(output_dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.log_with not in ["wandb", "tensorboard"]:
|
||||
warnings.warn(
|
||||
("Accelerator tracking only supports image logging if `log_with` is set to 'wandb' or 'tensorboard'.")
|
||||
)
|
||||
|
||||
if self.log_with == "wandb" and not is_torchvision_available():
|
||||
warnings.warn("Wandb image logging requires torchvision to be installed")
|
||||
|
||||
if self.train_use_8bit_adam and not is_bitsandbytes_available():
|
||||
raise ImportError(
|
||||
"You need to install bitsandbytes to use 8bit Adam. "
|
||||
"You can install it with `pip install bitsandbytes`."
|
||||
)
|
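A minimal sketch of instantiating this config; the field names come from the dataclass above, while the concrete values and the import path are only illustrative:

```python
from trl import DDPOConfig  # assumed export path

config = DDPOConfig(
    num_epochs=200,
    sample_num_steps=50,
    sample_batch_size=6,
    sample_num_batches_per_epoch=4,
    train_batch_size=3,
    train_gradient_accumulation_steps=1,
    per_prompt_stat_tracking=True,
    per_prompt_stat_tracking_buffer_size=32,
    mixed_precision="fp16",
    log_with="wandb",
    tracker_project_name="trl-ddpo-demo",
)
```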
576
trl/trainer/ddpo_trainer.py
Normal file
@ -0,0 +1,576 @@
|
||||
# Copyright 2023 DDPO-pytorch authors (Kevin Black), metric-space, The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from concurrent import futures
|
||||
from typing import Any, Callable, Optional, Tuple
|
||||
from warnings import warn
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
|
||||
from ..models import DDPOStableDiffusionPipeline
|
||||
from . import BaseTrainer, DDPOConfig
|
||||
from .utils import PerPromptStatTracker
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DDPOTrainer(BaseTrainer):
|
||||
"""
|
||||
The DDPOTrainer uses Denoising Diffusion Policy Optimization (DDPO) to optimise diffusion models.
Note: this trainer is heavily inspired by the work here: https://github.com/kvablack/ddpo-pytorch
As of now only Stable Diffusion based pipelines are supported.
|
||||
|
||||
Attributes:
|
||||
**config** (`DDPOConfig`) -- Configuration object for DDPOTrainer. Check the documentation of `DDPOConfig` for more
details.
|
||||
**reward_function** (Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor]) -- Reward function to be used
|
||||
**prompt_function** (Callable[[], Tuple[str, Any]]) -- Function to generate prompts to guide model
|
||||
**sd_pipeline** (`DDPOStableDiffusionPipeline`) -- Stable Diffusion pipeline to be used for training.
|
||||
**image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) -- Hook to be called to log images
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: DDPOConfig,
|
||||
reward_function: Callable[[torch.Tensor, Tuple[str], Tuple[Any]], torch.Tensor],
|
||||
prompt_function: Callable[[], Tuple[str, Any]],
|
||||
sd_pipeline: DDPOStableDiffusionPipeline,
|
||||
image_samples_hook: Optional[Callable[[Any, Any, Any], Any]] = None,
|
||||
):
|
||||
if image_samples_hook is None:
|
||||
warn("No image_samples_hook provided; no images will be logged")
|
||||
|
||||
self.prompt_fn = prompt_function
|
||||
self.reward_fn = reward_function
|
||||
self.config = config
|
||||
self.image_samples_callback = image_samples_hook
|
||||
|
||||
accelerator_project_config = ProjectConfiguration(**self.config.project_kwargs)
|
||||
|
||||
if self.config.resume_from:
|
||||
self.config.resume_from = os.path.normpath(os.path.expanduser(self.config.resume_from))
|
||||
if "checkpoint_" not in os.path.basename(self.config.resume_from):
|
||||
# get the most recent checkpoint in this directory
|
||||
checkpoints = list(
|
||||
filter(
|
||||
lambda x: "checkpoint_" in x,
|
||||
os.listdir(self.config.resume_from),
|
||||
)
|
||||
)
|
||||
if len(checkpoints) == 0:
|
||||
raise ValueError(f"No checkpoints found in {self.config.resume_from}")
|
||||
checkpoint_numbers = sorted([int(x.split("_")[-1]) for x in checkpoints])
|
||||
self.config.resume_from = os.path.join(
|
||||
self.config.resume_from,
|
||||
f"checkpoint_{checkpoint_numbers[-1]}",
|
||||
)
|
||||
|
||||
accelerator_project_config.iteration = checkpoint_numbers[-1] + 1
|
||||
|
||||
# number of timesteps within each trajectory to train on
|
||||
self.num_train_timesteps = int(self.config.sample_num_steps * self.config.train_timestep_fraction)
|
||||
|
||||
self.accelerator = Accelerator(
|
||||
log_with=self.config.log_with,
|
||||
mixed_precision=self.config.mixed_precision,
|
||||
project_config=accelerator_project_config,
|
||||
# we always accumulate gradients across timesteps; we want config.train.gradient_accumulation_steps to be the
|
||||
# number of *samples* we accumulate across, so we need to multiply by the number of training timesteps to get
|
||||
# the total number of optimizer steps to accumulate across.
|
||||
gradient_accumulation_steps=self.config.train_gradient_accumulation_steps * self.num_train_timesteps,
|
||||
**self.config.accelerator_kwargs,
|
||||
)
|
||||
|
||||
is_okay, message = self._config_check()
|
||||
if not is_okay:
|
||||
raise ValueError(message)
|
||||
|
||||
is_using_tensorboard = config.log_with is not None and config.log_with == "tensorboard"
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
self.accelerator.init_trackers(
|
||||
self.config.tracker_project_name,
|
||||
config=dict(ddpo_trainer_config=config.to_dict()) if not is_using_tensorboard else config.to_dict(),
|
||||
init_kwargs=self.config.tracker_kwargs,
|
||||
)
|
||||
|
||||
logger.info(f"\n{config}")
|
||||
|
||||
set_seed(self.config.seed, device_specific=True)
|
||||
|
||||
self.sd_pipeline = sd_pipeline
|
||||
|
||||
self.sd_pipeline.set_progress_bar_config(
|
||||
position=1,
|
||||
disable=not self.accelerator.is_local_main_process,
|
||||
leave=False,
|
||||
desc="Timestep",
|
||||
dynamic_ncols=True,
|
||||
)
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full precision is not required.
|
||||
if self.accelerator.mixed_precision == "fp16":
|
||||
inference_dtype = torch.float16
|
||||
elif self.accelerator.mixed_precision == "bf16":
|
||||
inference_dtype = torch.bfloat16
|
||||
else:
|
||||
inference_dtype = torch.float32
|
||||
|
||||
self.sd_pipeline.vae.to(self.accelerator.device, dtype=inference_dtype)
|
||||
self.sd_pipeline.text_encoder.to(self.accelerator.device, dtype=inference_dtype)
|
||||
self.sd_pipeline.unet.to(self.accelerator.device, dtype=inference_dtype)
|
||||
|
||||
trainable_layers = self.sd_pipeline.get_trainable_layers()
|
||||
|
||||
self.accelerator.register_save_state_pre_hook(self._save_model_hook)
|
||||
self.accelerator.register_load_state_pre_hook(self._load_model_hook)
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
if self.config.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
self.optimizer = self._setup_optimizer(trainable_layers.parameters())
|
||||
|
||||
self.neg_prompt_embed = self.sd_pipeline.text_encoder(
|
||||
self.sd_pipeline.tokenizer(
|
||||
[""] if self.config.negative_prompts is None else self.config.negative_prompts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
||||
).input_ids.to(self.accelerator.device)
|
||||
)[0]
|
||||
|
||||
if config.per_prompt_stat_tracking:
|
||||
self.stat_tracker = PerPromptStatTracker(
|
||||
config.per_prompt_stat_tracking_buffer_size,
|
||||
config.per_prompt_stat_tracking_min_count,
|
||||
)
|
||||
|
||||
# NOTE: for some reason, autocast is necessary for non-LoRA training; for LoRA training it is
# not needed and only uses more memory
|
||||
self.autocast = self.sd_pipeline.autocast or self.accelerator.autocast
|
||||
|
||||
self.trainable_layers, self.optimizer = self.accelerator.prepare(trainable_layers, self.optimizer)
|
||||
|
||||
if self.config.async_reward_computation:
|
||||
self.executor = futures.ThreadPoolExecutor(max_workers=config.max_workers)
|
||||
|
||||
if config.resume_from:
|
||||
logger.info(f"Resuming from {config.resume_from}")
|
||||
self.accelerator.load_state(config.resume_from)
|
||||
self.first_epoch = int(config.resume_from.split("_")[-1]) + 1
|
||||
else:
|
||||
self.first_epoch = 0
|
||||
|
||||
def compute_rewards(self, prompt_image_pairs, is_async=False):
|
||||
if not is_async:
|
||||
rewards = []
|
||||
for images, prompts, prompt_metadata in prompt_image_pairs:
|
||||
reward, reward_metadata = self.reward_fn(images, prompts, prompt_metadata)
|
||||
rewards.append(
|
||||
(
|
||||
torch.as_tensor(reward, device=self.accelerator.device),
|
||||
reward_metadata,
|
||||
)
|
||||
)
|
||||
else:
|
||||
rewards = self.executor.map(lambda x: self.reward_fn(*x), prompt_image_pairs)
|
||||
rewards = [
|
||||
(torch.as_tensor(reward.result(), device=self.accelerator.device), reward_metadata.result())
|
||||
for reward, reward_metadata in rewards
|
||||
]
|
||||
|
||||
return zip(*rewards)
|
||||
|
||||
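For reference, a sketch of the two callables this trainer expects, matching how `compute_rewards` above and `_generate_samples` below call them (the constant reward and fixed prompt are purely illustrative):

```python
import torch


def reward_fn(images, prompts, prompt_metadata):
    # images: (batch, 3, H, W) tensor; prompts / prompt_metadata: tuples of length batch.
    # Must return (rewards, reward_metadata); a dummy constant reward is used here.
    return torch.ones(images.shape[0]), {}


def prompt_fn():
    # Must return (prompt, prompt_metadata) for a single sample.
    return "a photo of a corgi", {}
```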
def step(self, epoch: int, global_step: int):
|
||||
"""
|
||||
Perform a single step of training.
|
||||
|
||||
Args:
|
||||
epoch (int): The current epoch.
|
||||
global_step (int): The current global step.
|
||||
|
||||
Side Effects:
|
||||
- Model weights are updated
|
||||
- Logs the statistics to the accelerator trackers.
|
||||
- If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.
|
||||
|
||||
Returns:
|
||||
global_step (int): The updated global step.
|
||||
|
||||
"""
|
||||
samples, prompt_image_data = self._generate_samples(
|
||||
iterations=self.config.sample_num_batches_per_epoch,
|
||||
batch_size=self.config.sample_batch_size,
|
||||
)
|
||||
|
||||
# collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
|
||||
samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
|
||||
rewards, rewards_metadata = self.compute_rewards(
|
||||
prompt_image_data, is_async=self.config.async_reward_computation
|
||||
)
|
||||
|
||||
for i, image_data in enumerate(prompt_image_data):
|
||||
image_data.extend([rewards[i], rewards_metadata[i]])
|
||||
|
||||
if self.image_samples_callback is not None:
|
||||
self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])
|
||||
|
||||
rewards = torch.cat(rewards)
|
||||
rewards = self.accelerator.gather(rewards).cpu().numpy()
|
||||
|
||||
self.accelerator.log(
|
||||
{
|
||||
"reward": rewards,
|
||||
"epoch": epoch,
|
||||
"reward_mean": rewards.mean(),
|
||||
"reward_std": rewards.std(),
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
if self.config.per_prompt_stat_tracking:
|
||||
# gather the prompts across processes
|
||||
prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
|
||||
prompts = self.sd_pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True)
|
||||
advantages = self.stat_tracker.update(prompts, rewards)
|
||||
else:
|
||||
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
|
||||
|
||||
# ungather advantages; keep the entries corresponding to the samples on this process
|
||||
samples["advantages"] = (
|
||||
torch.as_tensor(advantages)
|
||||
.reshape(self.accelerator.num_processes, -1)[self.accelerator.process_index]
|
||||
.to(self.accelerator.device)
|
||||
)
|
||||
|
||||
del samples["prompt_ids"]
|
||||
|
||||
total_batch_size, num_timesteps = samples["timesteps"].shape
|
||||
|
||||
for inner_epoch in range(self.config.train_num_inner_epochs):
|
||||
# shuffle samples along batch dimension
|
||||
perm = torch.randperm(total_batch_size, device=self.accelerator.device)
|
||||
samples = {k: v[perm] for k, v in samples.items()}
|
||||
|
||||
# shuffle along time dimension independently for each sample
|
||||
# build a separate random permutation of the timestep indices for every sample in the batch
|
||||
perms = torch.stack(
|
||||
[torch.randperm(num_timesteps, device=self.accelerator.device) for _ in range(total_batch_size)]
|
||||
)
|
||||
|
||||
for key in ["timesteps", "latents", "next_latents", "log_probs"]:
|
||||
samples[key] = samples[key][
|
||||
torch.arange(total_batch_size, device=self.accelerator.device)[:, None],
|
||||
perms,
|
||||
]
|
||||
|
||||
original_keys = samples.keys()
|
||||
original_values = samples.values()
|
||||
# rebatch them as user defined train_batch_size is different from sample_batch_size
|
||||
reshaped_values = [v.reshape(-1, self.config.train_batch_size, *v.shape[1:]) for v in original_values]
|
||||
|
||||
# Transpose the list of original values
|
||||
transposed_values = zip(*reshaped_values)
|
||||
# Create new dictionaries for each row of transposed values
|
||||
samples_batched = [dict(zip(original_keys, row_values)) for row_values in transposed_values]
|
||||
|
||||
self.sd_pipeline.unet.train()
|
||||
global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)
|
||||
# ensure optimization step at the end of the inner epoch
|
||||
if not self.accelerator.sync_gradients:
|
||||
raise ValueError(
|
||||
"Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
|
||||
)
|
||||
|
||||
if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
|
||||
self.accelerator.save_state()
|
||||
|
||||
return global_step
|
||||
|
||||
def calculate_loss(self, latents, timesteps, next_latents, log_probs, advantages, embeds):
|
||||
"""
|
||||
Calculate the loss for a batch of an unpacked sample
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor):
|
||||
The latents sampled from the diffusion model, shape: [batch_size, num_steps, ...]
|
||||
timesteps (torch.Tensor):
|
||||
The timesteps sampled from the diffusion model, shape: [batch_size]
|
||||
next_latents (torch.Tensor):
|
||||
The next latents sampled from the diffusion model, shape: [batch_size, num_steps, ...]
|
||||
log_probs (torch.Tensor):
|
||||
The log probabilities of the latents, shape: [batch_size]
|
||||
advantages (torch.Tensor):
|
||||
The advantages of the latents, shape: [batch_size]
|
||||
embeds (torch.Tensor):
|
||||
The embeddings of the prompts, shape: [2*batch_size or batch_size, ...]
|
||||
Note: the "or" is because if train_cfg is True, the expectation is that negative prompts are concatenated to the embeds
|
||||
|
||||
Returns:
|
||||
loss (torch.Tensor), approx_kl (torch.Tensor), clipfrac (torch.Tensor)
|
||||
(all of these are of shape (1,))
|
||||
"""
|
||||
with self.autocast():
|
||||
if self.config.train_cfg:
|
||||
noise_pred = self.sd_pipeline.unet(
|
||||
torch.cat([latents] * 2),
|
||||
torch.cat([timesteps] * 2),
|
||||
embeds,
|
||||
).sample
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.config.sample_guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
else:
|
||||
noise_pred = self.sd_pipeline.unet(
|
||||
latents,
|
||||
timesteps,
|
||||
embeds,
|
||||
).sample
|
||||
# compute the log prob of next_latents given latents under the current model
|
||||
|
||||
scheduler_step_output = self.sd_pipeline.scheduler_step(
|
||||
noise_pred,
|
||||
timesteps,
|
||||
latents,
|
||||
eta=self.config.sample_eta,
|
||||
prev_sample=next_latents,
|
||||
)
|
||||
|
||||
log_prob = scheduler_step_output.log_probs
|
||||
|
||||
advantages = torch.clamp(
|
||||
advantages,
|
||||
-self.config.train_adv_clip_max,
|
||||
self.config.train_adv_clip_max,
|
||||
)
|
||||
|
||||
ratio = torch.exp(log_prob - log_probs)
|
||||
|
||||
loss = self.loss(advantages, self.config.train_clip_range, ratio)
|
||||
|
||||
approx_kl = 0.5 * torch.mean((log_prob - log_probs) ** 2)
|
||||
|
||||
clipfrac = torch.mean((torch.abs(ratio - 1.0) > self.config.train_clip_range).float())
|
||||
|
||||
return loss, approx_kl, clipfrac
|
||||
|
||||
def loss(
|
||||
self,
|
||||
advantages: torch.Tensor,
|
||||
clip_range: float,
|
||||
ratio: torch.Tensor,
|
||||
):
|
||||
unclipped_loss = -advantages * ratio
|
||||
clipped_loss = -advantages * torch.clamp(
|
||||
ratio,
|
||||
1.0 - clip_range,
|
||||
1.0 + clip_range,
|
||||
)
|
||||
return torch.mean(torch.maximum(unclipped_loss, clipped_loss))
|
||||
|
||||
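As a small sanity check on the clipped surrogate above, a toy computation with arbitrary numbers:

```python
import torch

advantages = torch.tensor([1.0, -1.0])
ratio = torch.tensor([1.5, 0.5])
clip_range = 0.2

unclipped = -advantages * ratio                                             # [-1.5, 0.5]
clipped = -advantages * torch.clamp(ratio, 1 - clip_range, 1 + clip_range)  # [-1.2, 0.8]
loss = torch.mean(torch.maximum(unclipped, clipped))                        # (-1.2 + 0.8) / 2
print(loss)  # tensor(-0.2000)
```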
def _setup_optimizer(self, trainable_layers_parameters):
|
||||
if self.config.train_use_8bit_adam:
|
||||
import bitsandbytes
|
||||
|
||||
optimizer_cls = bitsandbytes.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
return optimizer_cls(
|
||||
trainable_layers_parameters,
|
||||
lr=self.config.train_learning_rate,
|
||||
betas=(self.config.train_adam_beta1, self.config.train_adam_beta2),
|
||||
weight_decay=self.config.train_adam_weight_decay,
|
||||
eps=self.config.train_adam_epsilon,
|
||||
)
|
||||
|
||||
def _save_model_hook(self, models, weights, output_dir):
|
||||
self.sd_pipeline.save_checkpoint(models, weights, output_dir)
|
||||
weights.pop() # ensures that accelerate doesn't try to handle saving of the model
|
||||
|
||||
def _load_model_hook(self, models, input_dir):
|
||||
self.sd_pipeline.load_checkpoint(models, input_dir)
|
||||
models.pop() # ensures that accelerate doesn't try to handle loading of the model
|
||||
|
||||
def _generate_samples(self, iterations, batch_size):
|
||||
"""
|
||||
Generate samples from the model
|
||||
|
||||
Args:
|
||||
iterations (int): Number of iterations to generate samples for
|
||||
batch_size (int): Batch size to use for sampling
|
||||
|
||||
Returns:
|
||||
samples (List[Dict[str, torch.Tensor]]), prompt_image_pairs (List[List[Any]])
|
||||
"""
|
||||
samples = []
|
||||
prompt_image_pairs = []
|
||||
self.sd_pipeline.unet.eval()
|
||||
|
||||
sample_neg_prompt_embeds = self.neg_prompt_embed.repeat(batch_size, 1, 1)
|
||||
|
||||
for _ in range(iterations):
|
||||
prompts, prompt_metadata = zip(*[self.prompt_fn() for _ in range(batch_size)])
|
||||
|
||||
prompt_ids = self.sd_pipeline.tokenizer(
|
||||
prompts,
|
||||
return_tensors="pt",
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
max_length=self.sd_pipeline.tokenizer.model_max_length,
|
||||
).input_ids.to(self.accelerator.device)
|
||||
prompt_embeds = self.sd_pipeline.text_encoder(prompt_ids)[0]
|
||||
|
||||
with self.autocast():
|
||||
sd_output = self.sd_pipeline(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=sample_neg_prompt_embeds,
|
||||
num_inference_steps=self.config.sample_num_steps,
|
||||
guidance_scale=self.config.sample_guidance_scale,
|
||||
eta=self.config.sample_eta,
|
||||
output_type="pt",
|
||||
)
|
||||
|
||||
images = sd_output.images
|
||||
latents = sd_output.latents
|
||||
log_probs = sd_output.log_probs
|
||||
|
||||
latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, ...)
|
||||
log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1)
|
||||
timesteps = self.sd_pipeline.scheduler.timesteps.repeat(batch_size, 1) # (batch_size, num_steps)
|
||||
|
||||
samples.append(
|
||||
{
|
||||
"prompt_ids": prompt_ids,
|
||||
"prompt_embeds": prompt_embeds,
|
||||
"timesteps": timesteps,
|
||||
"latents": latents[:, :-1], # each entry is the latent before timestep t
|
||||
"next_latents": latents[:, 1:], # each entry is the latent after timestep t
|
||||
"log_probs": log_probs,
|
||||
"negative_prompt_embeds": sample_neg_prompt_embeds,
|
||||
}
|
||||
)
|
||||
prompt_image_pairs.append([images, prompts, prompt_metadata])
|
||||
|
||||
return samples, prompt_image_pairs
|
||||
|
||||
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
|
||||
"""
|
||||
Train on a batch of samples. Main training segment
|
||||
|
||||
Args:
|
||||
inner_epoch (int): The current inner epoch
|
||||
epoch (int): The current epoch
|
||||
global_step (int): The current global step
|
||||
batched_samples (List[Dict[str, torch.Tensor]]): The batched samples to train on
|
||||
|
||||
Side Effects:
|
||||
- Model weights are updated
|
||||
- Logs the statistics to the accelerator trackers.
|
||||
|
||||
Returns:
|
||||
global_step (int): The updated global step
|
||||
"""
|
||||
info = defaultdict(list)
|
||||
for i, sample in enumerate(batched_samples):
|
||||
if self.config.train_cfg:
|
||||
# concat negative prompts to sample prompts to avoid two forward passes
|
||||
embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
|
||||
else:
|
||||
embeds = sample["prompt_embeds"]
|
||||
|
||||
for j in range(self.num_train_timesteps):
|
||||
with self.accelerator.accumulate(self.sd_pipeline.unet):
|
||||
loss, approx_kl, clipfrac = self.calculate_loss(
|
||||
sample["latents"][:, j],
|
||||
sample["timesteps"][:, j],
|
||||
sample["next_latents"][:, j],
|
||||
sample["log_probs"][:, j],
|
||||
sample["advantages"],
|
||||
embeds,
|
||||
)
|
||||
info["approx_kl"].append(approx_kl)
|
||||
info["clipfrac"].append(clipfrac)
|
||||
info["loss"].append(loss)
|
||||
|
||||
self.accelerator.backward(loss)
|
||||
if self.accelerator.sync_gradients:
|
||||
self.accelerator.clip_grad_norm_(
|
||||
self.trainable_layers.parameters(),
|
||||
self.config.train_max_grad_norm,
|
||||
)
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if self.accelerator.sync_gradients:
|
||||
# log training-related stuff
|
||||
info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
|
||||
info = self.accelerator.reduce(info, reduction="mean")
|
||||
info.update({"epoch": epoch, "inner_epoch": inner_epoch})
|
||||
self.accelerator.log(info, step=global_step)
|
||||
global_step += 1
|
||||
info = defaultdict(list)
|
||||
return global_step
|
||||
|
||||
def _config_check(self) -> Tuple[bool, str]:
|
||||
samples_per_epoch = (
|
||||
self.config.sample_batch_size * self.accelerator.num_processes * self.config.sample_num_batches_per_epoch
|
||||
)
|
||||
total_train_batch_size = (
|
||||
self.config.train_batch_size
|
||||
* self.accelerator.num_processes
|
||||
* self.config.train_gradient_accumulation_steps
|
||||
)
|
||||
|
||||
if not self.config.sample_batch_size >= self.config.train_batch_size:
|
||||
return (
|
||||
False,
|
||||
f"Sample batch size ({self.config.sample_batch_size}) must be greater than or equal to the train batch size ({self.config.train_batch_size})",
|
||||
)
|
||||
if not self.config.sample_batch_size % self.config.train_batch_size == 0:
|
||||
return (
|
||||
False,
|
||||
f"Sample batch size ({self.config.sample_batch_size}) must be divisible by the train batch size ({self.config.train_batch_size})",
|
||||
)
|
||||
if not samples_per_epoch % total_train_batch_size == 0:
|
||||
return (
|
||||
False,
|
||||
f"Number of samples per epoch ({samples_per_epoch}) must be divisible by the total train batch size ({total_train_batch_size})",
|
||||
)
|
||||
return True, ""
|
||||
|
||||
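To make the three checks above concrete, a worked example with illustrative numbers and two processes:

```python
num_processes = 2
sample_batch_size, train_batch_size = 6, 3
sample_num_batches_per_epoch, train_gradient_accumulation_steps = 4, 1

samples_per_epoch = sample_batch_size * num_processes * sample_num_batches_per_epoch           # 48
total_train_batch_size = train_batch_size * num_processes * train_gradient_accumulation_steps  # 6

assert sample_batch_size >= train_batch_size             # at least one train batch per sample batch
assert sample_batch_size % train_batch_size == 0         # sample batches split evenly
assert samples_per_epoch % total_train_batch_size == 0   # 48 % 6 == 0, so this config passes
```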
def train(self, epochs: Optional[int] = None):
|
||||
"""
|
||||
Train the model for a given number of epochs
|
||||
"""
|
||||
global_step = 0
|
||||
if epochs is None:
|
||||
epochs = self.config.num_epochs
|
||||
for epoch in range(self.first_epoch, epochs):
|
||||
global_step = self.step(epoch, global_step)
|
||||
|
||||
def _save_pretrained(self, save_directory):
|
||||
self.sd_pipeline.save_pretrained(save_directory)
|
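Putting the pieces together, a condensed training sketch under the same assumptions as the earlier snippets (placeholder checkpoint, dummy reward and prompt functions, assumed export paths):

```python
import torch

from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline  # assumed export paths


def prompt_fn():
    return "a photo of a corgi", {}


def reward_fn(images, prompts, prompt_metadata):
    return torch.ones(images.shape[0]), {}


config = DDPOConfig(num_epochs=100, sample_batch_size=2, train_batch_size=1, mixed_precision="fp16")
pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5", use_lora=True)

trainer = DDPOTrainer(config, reward_fn, prompt_fn, pipeline)
trainer.train()  # sample -> reward -> clipped policy-gradient update, for config.num_epochs epochs
```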
@ -24,11 +24,12 @@ from transformers import DataCollator, PreTrainedModel, PreTrainedTokenizerBase,
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
from ..import_utils import is_peft_available
|
||||
from .utils import DPODataCollatorWithPadding, pad_to_length
|
||||
from ..models import create_reference_model
|
||||
from .utils import DPODataCollatorWithPadding, disable_dropout_in_model, pad_to_length
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import get_peft_model, prepare_model_for_int8_training
|
||||
from peft import PeftModel, get_peft_model, prepare_model_for_int8_training
|
||||
|
||||
|
||||
class DPOTrainer(Trainer):
|
||||
@ -39,7 +40,8 @@ class DPOTrainer(Trainer):
|
||||
model (`transformers.PreTrainedModel`):
|
||||
The model to train, preferably an `AutoModelForCausalLM`.
|
||||
ref_model (`PreTrainedModelWrapper`):
|
||||
Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss.
|
||||
Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no
reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
|
||||
beta (`float`, defaults to 0.1):
|
||||
The beta factor in DPO loss. Higher beta means less divergence from the initial policy.
|
||||
args (`transformers.TrainingArguments`):
|
||||
@ -73,12 +75,14 @@ class DPOTrainer(Trainer):
|
||||
The maximum length of the prompt. This argument is required if you want to use the default data collator.
|
||||
peft_config (`Dict`, defaults to `None`):
|
||||
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
||||
disable_dropout (`bool`, defaults to `True`):
|
||||
Whether or not to disable dropouts in `model` and `ref_model`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module] = None,
|
||||
ref_model: Union[PreTrainedModel, nn.Module] = None,
|
||||
ref_model: Optional[Union[PreTrainedModel, nn.Module]] = None,
|
||||
beta: float = 0.1,
|
||||
args: TrainingArguments = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
@ -98,6 +102,7 @@ class DPOTrainer(Trainer):
|
||||
max_length: Optional[int] = None,
|
||||
max_prompt_length: Optional[int] = None,
|
||||
peft_config: Optional[Dict] = None,
|
||||
disable_dropout: bool = True,
|
||||
):
|
||||
if not is_peft_available() and peft_config is not None:
|
||||
raise ValueError(
|
||||
@ -108,6 +113,16 @@ class DPOTrainer(Trainer):
|
||||
model = prepare_model_for_int8_training(model)
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
|
||||
|
||||
if ref_model:
|
||||
self.ref_model = ref_model
|
||||
elif self.is_peft_model:
|
||||
# The `model` with adapters turned off will be used as the reference model
|
||||
self.ref_model = None
|
||||
else:
|
||||
self.ref_model = create_reference_model(model)
|
||||
|
||||
if data_collator is None:
|
||||
if tokenizer is None:
|
||||
raise ValueError(
|
||||
@ -150,11 +165,15 @@ class DPOTrainer(Trainer):
|
||||
else:
|
||||
self.use_dpo_data_collator = False
|
||||
|
||||
if disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
if self.ref_model is not None:
|
||||
disable_dropout_in_model(self.ref_model)
|
||||
|
||||
self.label_pad_token_id = label_pad_token_id
|
||||
self.padding_value = padding_value
|
||||
|
||||
self.beta = beta
|
||||
self.ref_model = ref_model
|
||||
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
@ -172,14 +191,19 @@ class DPOTrainer(Trainer):
|
||||
preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# Since we inherit from trainer we always have access to an accelerator
|
||||
if hasattr(self, "accelerator"):
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
else:
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError(
|
||||
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
||||
)
|
||||
|
||||
if self.ref_model is None:
|
||||
if not hasattr(self.accelerator.unwrap_model(self.model), "disable_adapter"):
|
||||
raise ValueError(
|
||||
"You are using a `peft` version that does not support `disable_adapter`. Please update your `peft` version to the latest version."
|
||||
)
|
||||
else:
|
||||
self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
|
||||
|
||||
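A condensed sketch of the pattern this change enables: passing a `peft_config` with `ref_model=None`, so the reference log-probs come from the same model with its adapters disabled. Model name, LoRA settings and the tiny in-memory dataset are placeholders; the dataset only needs `prompt`, `chosen` and `rejected` columns.

```python
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments

from trl import DPOTrainer  # assumed export path

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

train_dataset = Dataset.from_dict(
    {
        "prompt": ["The capital of France is"],
        "chosen": [" Paris."],
        "rejected": [" London."],
    }
)

trainer = DPOTrainer(
    model,
    ref_model=None,  # adapters are disabled on `model` to compute the reference log-probs
    beta=0.1,
    args=TrainingArguments(
        output_dir="dpo_lora_out",
        per_device_train_batch_size=1,
        max_steps=10,
        remove_unused_columns=False,  # keep the raw text columns for the DPO data collator
    ),
    train_dataset=train_dataset,
    tokenizer=tokenizer,
    max_length=128,
    max_prompt_length=64,
    peft_config=LoraConfig(r=8, lora_alpha=16, task_type="CAUSAL_LM"),
)
trainer.train()
```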
def concatenated_inputs(self, batch: Dict[str, Union[List, torch.LongTensor]]) -> Dict[str, torch.LongTensor]:
|
||||
"""Concatenate the chosen and rejected inputs into a single tensor.
|
||||
|
||||
@ -319,12 +343,21 @@ class DPOTrainer(Trainer):
|
||||
policy_rejected_logits,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
with torch.no_grad():
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
) = self.concatenated_forward(self.ref_model, batch)
|
||||
if self.ref_model is None:
|
||||
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
) = self.concatenated_forward(self.model, batch)
|
||||
else:
|
||||
(
|
||||
reference_chosen_logps,
|
||||
reference_rejected_logps,
|
||||
_,
|
||||
_,
|
||||
) = self.concatenated_forward(self.ref_model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.dpo_loss(
|
||||
policy_chosen_logps,
|
||||
@ -378,13 +411,23 @@ class DPOTrainer(Trainer):
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
|
||||
reference_output = self.ref_model.generate(
|
||||
batch["prompt_input_ids"],
|
||||
attention_mask=batch["prompt_attention_mask"],
|
||||
max_length=self.config.max_length,
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
if self.ref_model is None:
|
||||
with self.accelerator.unwrap_model(self.model).disable_adapter():
|
||||
reference_output = self.model.generate(
|
||||
batch["prompt_input_ids"],
|
||||
attention_mask=batch["prompt_attention_mask"],
|
||||
max_length=self.config.max_length,
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
else:
|
||||
reference_output = self.ref_model.generate(
|
||||
batch["prompt_input_ids"],
|
||||
attention_mask=batch["prompt_attention_mask"],
|
||||
max_length=self.config.max_length,
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
|
||||
policy_output = pad_to_length(policy_output, self.config.max_length, self.tokenizer.pad_token_id)
|
||||
policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
|
||||
|
@ -21,6 +21,8 @@ from typing import Optional
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from trl.trainer.utils import exact_div
|
||||
|
||||
from ..core import flatten_dict
|
||||
|
||||
|
||||
@ -162,6 +164,11 @@ class PPOConfig(object):
|
||||
ratio_threshold: Optional[float] = field(
|
||||
default=10.0, metadata={"help": "Skip mini-batches with high PPO ratios that can cause loss spikes"}
|
||||
)
|
||||
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
|
||||
use_score_norm: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
|
||||
)
|
||||
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})
|
||||
|
||||
def __post_init__(self):
|
||||
if self.forward_batch_size is not None:
|
||||
@ -170,6 +177,15 @@ class PPOConfig(object):
|
||||
)
|
||||
self.mini_batch_size = self.forward_batch_size
|
||||
|
||||
self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps
|
||||
exact_div(
|
||||
self.batch_size,
|
||||
self.backward_batch_size,
|
||||
"`batch_size`",
|
||||
"`mini_batch_size * gradient_accumulation_steps`",
|
||||
"`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`",
|
||||
)
|
||||
|
||||
# check if wandb is installed
|
||||
if self.log_with == "wandb":
|
||||
# raise error if wandb is not installed
|
||||
|
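A quick numeric illustration of the constraint enforced by `exact_div` above (values are arbitrary but consistent):

```python
batch_size = 128
mini_batch_size = 16
gradient_accumulation_steps = 4

backward_batch_size = mini_batch_size * gradient_accumulation_steps  # 64
assert batch_size % backward_batch_size == 0  # 128 is a multiple of 64, so this config is accepted
```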
@ -53,7 +53,7 @@ from ..core import (
|
||||
)
|
||||
from ..import_utils import is_torch_greater_2_0
|
||||
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
|
||||
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig
|
||||
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
|
||||
|
||||
|
||||
MODEL_CARD_TEMPLATE = """---
|
||||
@ -66,7 +66,7 @@ tags:
|
||||
|
||||
# {model_name}
|
||||
|
||||
This is a [TRL language model](https://github.com/lvwerra/trl) that has been fine-tuned with reinforcement learning to
|
||||
This is a [TRL language model](https://github.com/huggingface/trl) that has been fine-tuned with reinforcement learning to
|
||||
guide the model outputs according to a value, function, or human feedback. The model can be used for text generation.
|
||||
|
||||
## Usage
|
||||
@ -207,6 +207,7 @@ class PPOTrainer(BaseTrainer):
|
||||
self.model_params = filter(lambda p: p.requires_grad, self.model.parameters())
|
||||
self.is_encoder_decoder = hasattr(self.model, "is_encoder_decoder")
|
||||
self.is_peft_model = getattr(self.model, "is_peft_model", False)
|
||||
self.is_using_text_environment = getattr(config, "use_text_environment", False)
|
||||
|
||||
if isinstance(ref_model, SUPPORTED_ARCHITECTURES):
|
||||
self.ref_model = ref_model
|
||||
@ -344,6 +345,8 @@ class PPOTrainer(BaseTrainer):
|
||||
|
||||
PPODecorators.optimize_cuda_cache = self.config.optimize_cuda_cache
|
||||
|
||||
self.running = RunningMoments(self.accelerator)
|
||||
|
||||
def _filter_kwargs(self, kwargs, target_func):
|
||||
"""
|
||||
filter the keyword arguments that are supported by the target function.
|
||||
@ -388,7 +391,7 @@ class PPOTrainer(BaseTrainer):
|
||||
signature = inspect.signature(self.model.forward)
|
||||
self._signature_columns = list(signature.parameters.keys())
|
||||
# label => sentiment | we need query and response for logging purpose
|
||||
self._signature_columns += list(set(["label", "query", "response"]))
|
||||
self._signature_columns += ["label", "query", "response"]
|
||||
|
||||
# Adapted from transformers.Trainer._remove_unused_columns
|
||||
def _remove_unused_columns(self, dataset: "Dataset"):
|
||||
@ -524,6 +527,7 @@ class PPOTrainer(BaseTrainer):
|
||||
queries: List[torch.LongTensor],
|
||||
responses: List[torch.LongTensor],
|
||||
scores: List[torch.FloatTensor],
|
||||
masks: Optional[List[torch.LongTensor]] = None,
|
||||
):
|
||||
"""
|
||||
Check if the input data is valid for training.
|
||||
@ -537,6 +541,8 @@ class PPOTrainer(BaseTrainer):
|
||||
List of tensors containing the encoded responses of shape (`response_length`)
|
||||
scores (List[`torch.FloatTensor`]):
|
||||
List of tensors containing the scores.
|
||||
masks (List[`torch.LongTensor`], *optional*):
|
||||
List of tensors containing the masks of shape (`query_length` + `response_length`)
|
||||
Returns:
|
||||
`tuple`: The input processed data.
|
||||
"""
|
||||
@ -554,6 +560,7 @@ class PPOTrainer(BaseTrainer):
|
||||
queries = [tensor.to(self.current_device) for tensor in queries]
|
||||
responses = [tensor.to(self.current_device) for tensor in responses]
|
||||
scores = [tensor.to(self.current_device) for tensor in scores]
|
||||
masks = [tensor.to(self.current_device) for tensor in masks] if masks is not None else None
|
||||
|
||||
# squeeze scores if needed
|
||||
for i, score in enumerate(scores):
|
||||
@ -562,7 +569,7 @@ class PPOTrainer(BaseTrainer):
|
||||
elif score.dim() == 1:
|
||||
scores[i] = score.squeeze()
|
||||
|
||||
return queries, responses, scores
|
||||
return queries, responses, scores, masks
|
||||
|
||||
@PPODecorators.empty_cuda_cache()
|
||||
def step(
|
||||
@ -570,6 +577,7 @@ class PPOTrainer(BaseTrainer):
|
||||
queries: List[torch.LongTensor],
|
||||
responses: List[torch.LongTensor],
|
||||
scores: List[torch.FloatTensor],
|
||||
response_masks: Optional[List[torch.LongTensor]] = None,
|
||||
):
|
||||
"""
|
||||
Run a PPO optimisation step given a list of queries, model responses, and rewards.
|
||||
@ -581,18 +589,35 @@ class PPOTrainer(BaseTrainer):
|
||||
List of tensors containing the encoded responses of shape (`response_length`)
|
||||
scores (List[`torch.FloatTensor`]):
|
||||
List of tensors containing the scores.
|
||||
response_masks (List[`torch.FloatTensor`], *optional*):
|
||||
List of tensors containing masks of the response tokens.
|
||||
|
||||
Returns:
|
||||
`dict[str, Any]`: A summary of the training statistics
|
||||
"""
|
||||
bs = self.config.batch_size
|
||||
|
||||
queries, responses, scores = self._step_safety_checker(bs, queries, responses, scores)
|
||||
queries, responses, scores, response_masks = self._step_safety_checker(
|
||||
bs, queries, responses, scores, response_masks
|
||||
)
|
||||
scores = torch.tensor(scores)
|
||||
if self.config.use_score_scaling:
|
||||
# Score scaling
|
||||
scores_mean, scores_std = self.running.update(scores)
|
||||
if self.config.use_score_norm:
|
||||
scores = (scores - self.running.mean) / self.running.std
|
||||
else:
|
||||
scores /= self.running.std
|
||||
|
||||
if self.config.score_clip is not None:
|
||||
# Score clipping
|
||||
scores_dtype = scores.dtype
|
||||
scores = torch.clip(scores.float(), -self.config.score_clip, self.config.score_clip).to(dtype=scores_dtype)
|
||||
|
||||
# if we want to push best model to the hub
|
||||
if hasattr(self, "highest_reward"):
|
||||
if self.compare_step % self.config.compare_steps == 0:
|
||||
curr_mean_reward = torch.tensor(scores).mean()
|
||||
curr_mean_reward = scores.mean()
|
||||
# if the best reward ever seen
|
||||
if curr_mean_reward > self.highest_reward:
|
||||
self.highest_reward = curr_mean_reward
|
||||
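The hunk above wires reward-score scaling and clipping into `step()`. Below is a minimal standalone sketch of that logic, assuming the rewards are already collected into a 1-D tensor; `running_mean` and `running_std` are hypothetical stand-ins for the state the trainer keeps in its `RunningMoments` instance (added to `trl/trainer/utils.py` later in this diff), not trl API.

```python
# Standalone sketch of the score scaling / clipping added above (not trl's API).
# `running_mean` / `running_std` stand in for the trainer's RunningMoments state.
import torch

def scale_and_clip_scores(scores, running_mean, running_std, use_score_norm=False, score_clip=None):
    if use_score_norm:
        scores = (scores - running_mean) / running_std  # full standardisation
    else:
        scores = scores / running_std                   # rescale only
    if score_clip is not None:
        dtype = scores.dtype
        scores = torch.clip(scores.float(), -score_clip, score_clip).to(dtype=dtype)
    return scores

rewards = torch.tensor([0.2, 1.5, -0.7, 3.0])
print(scale_and_clip_scores(rewards, rewards.mean(), rewards.std(), use_score_norm=True, score_clip=2.0))
```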
@ -639,9 +664,13 @@ class PPOTrainer(BaseTrainer):

        with torch.no_grad():
            all_logprobs, logits_or_none, values, masks = self.batched_forward_pass(
                self.model, queries, responses, model_inputs, return_logits=full_kl_penalty
                self.model,
                queries,
                responses,
                model_inputs,
                response_masks=response_masks,
                return_logits=full_kl_penalty,
            )

            # for when the model is a peft model
            if self.is_peft_model and hasattr(
                self.accelerator.unwrap_model(self.model).pretrained_model,
@ -864,7 +893,6 @@ class PPOTrainer(BaseTrainer):

            input_data["decoder_input_ids"] = decoder_inputs["input_ids"]
            input_data["decoder_attention_mask"] = decoder_inputs["attention_mask"]

        else:
            input_ids = [torch.cat([q, r]) for q, r in zip(queries, responses)]
            input_data = self.data_collator(
@ -872,7 +900,6 @@ class PPOTrainer(BaseTrainer):
        ).to(self.current_device)

        input_data.pop("labels", None)  # we don't want to compute LM losses

        return input_data

    @PPODecorators.empty_cuda_cache()
@ -883,6 +910,7 @@ class PPOTrainer(BaseTrainer):
        responses: torch.Tensor,
        model_inputs: dict,
        return_logits: bool = False,
        response_masks: Optional[torch.Tensor] = None,
    ):
        """
        Calculate model outputs in multiple batches.
@ -913,6 +941,8 @@ class PPOTrainer(BaseTrainer):
            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs : (i + 1) * fbs]
            response_batch = responses[i * fbs : (i + 1) * fbs]
            if response_masks is not None:
                response_masks_batch = response_masks[i * fbs : (i + 1) * fbs]
            logits, _, values = model(**input_kwargs)

            if self.is_encoder_decoder:
@ -936,9 +966,15 @@ class PPOTrainer(BaseTrainer):
                    if attention_mask[j, 0] == 0:  # offset left padding
                        start += attention_mask[j, :].nonzero()[0]
                    end = start + len(response_batch[j])
                    if response_masks is not None:
                        response_masks_batch[j] = torch.cat(
                            (torch.zeros_like(query_batch[j]), response_masks_batch[j])
                        )[1:]

                masks[j, :start] = 0
                masks[j, end:] = 0
                if response_masks is not None:
                    masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end]

            if return_logits:
                all_logits.append(logits)
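With `response_masks` now threaded from `step()` into `batched_forward_pass()`, a caller can exclude individual response tokens from the PPO loss. A rough sketch of the calling side follows; the trainer, tokenizer, queries and rewards are assumed to come from an existing training loop, and masking the first two tokens is only illustrative.

```python
# Sketch: building per-token response masks (1 = keep in PPO loss, 0 = ignore)
# and handing them to the trainer. `ppo_trainer`, `queries` and `rewards` come
# from an existing training loop and are not constructed here.
import torch

responses = [torch.tensor([11, 12, 13, 14]), torch.tensor([21, 22, 23])]
response_masks = []
for response in responses:
    mask = torch.ones_like(response)
    mask[:2] = 0  # e.g. drop a forced prefix from the loss
    response_masks.append(mask)

# stats = ppo_trainer.step(queries, responses, rewards, response_masks=response_masks)
```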
@ -1186,8 +1222,8 @@ class PPOTrainer(BaseTrainer):
        mean_non_score_reward = masked_mean(
            data["non_score_reward"], mask
        )  # non_score_reward is size `batch_size`, `response_length`
        mean_scores = torch.stack(data["scores"]).mean()  # scores is size `batch_size`
        std_scores = torch.stack(data["scores"]).std()
        mean_scores = data["scores"].mean()  # scores is size `batch_size`
        std_scores = data["scores"].std()

        if mean_kl.item() < -1.0:
            # warn users
@ -1259,8 +1295,12 @@ class PPOTrainer(BaseTrainer):
        elif self.config.log_with == "wandb":
            import wandb

            table_rows = [list(r) for r in zip(batch["query"], batch["response"], rewards.cpu().tolist())]
            logs.update({"game_log": wandb.Table(columns=["query", "response", "reward"], rows=table_rows)})
            table_rows = [
                list(r) for r in zip(batch["query"], batch["response"], batch["answer"], rewards.cpu().tolist())
            ]
            logs.update(
                {"game_log": wandb.Table(columns=["query", "response", "answer", "reward"], rows=table_rows)}
            )
        # All reduce rewards if distributed
        if self.is_distributed:
            import torch.distributed as dist
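After this change, the batch dict used for the wandb `game_log` table appears to need an `"answer"` entry next to `"query"` and `"response"`. A sketch of the calling side, with the training loop, `stats` and `rewards` assumed to exist already:

```python
# Sketch only: the batch dict logged to wandb now carries an "answer" entry.
# `ppo_trainer`, `stats` and `rewards` come from an existing loop (not shown).
batch = {
    "query": ["What is 2 + 2?"],
    "response": ["2 + 2 = 4"],
    "answer": ["4"],  # reference answer shown alongside the generation
}
# ppo_trainer.log_stats(stats, batch, rewards)
```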
@ -1281,10 +1321,6 @@ class PPOTrainer(BaseTrainer):
        logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
        logs["env/reward_dist"] = rewards.cpu().numpy()

        logs["env/reward_mean"] = torch.mean(rewards).cpu().numpy().item()
        logs["env/reward_std"] = torch.std(rewards).cpu().numpy().item()
        logs["env/reward_dist"] = rewards.cpu().numpy()

        if self.config.log_with == "tensorboard":
            # update the current step
            self.current_step += 1
@ -1329,3 +1365,18 @@ class PPOTrainer(BaseTrainer):
        self.accelerator.unwrap_model(self.model).save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        self.create_model_card(save_directory)

    def _show_tokens(self, tokens, masks):
        from rich import print
        from rich.text import Text

        text = Text()

        for i, (token, mask) in enumerate(zip(tokens, masks)):
            if mask == 1:
                text.append(self.tokenizer.decode(token.item()), style="black on deep_sky_blue1")
                text.append(" ")
            else:
                text.append(self.tokenizer.decode(token.item()), style="black on cyan3")
                text.append(" ")
        print(text)
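`_show_tokens` is a small debugging helper: it colours each token by its mask so you can see which tokens enter the PPO loss. The sketch below mirrors that rendering with plain strings instead of tokenizer output so it runs on its own (rich is the only dependency); the token strings and masks are made up.

```python
# Standalone sketch of what _show_tokens renders: masked-in tokens on one
# background colour, masked-out tokens on another. Tokens/masks are toy values.
from rich import print
from rich.text import Text

tokens = ["Once", "upon", "a", "time", "<pad>", "<pad>"]
masks = [1, 1, 1, 1, 0, 0]

text = Text()
for token, mask in zip(tokens, masks):
    style = "black on deep_sky_blue1" if mask == 1 else "black on cyan3"
    text.append(token, style=style)
    text.append(" ")
print(text)
```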
trl/trainer/reward_trainer.py
@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import FrozenInstanceError, replace
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
@ -129,7 +130,10 @@ class RewardTrainer(Trainer):
            data_collator = RewardDataCollatorWithPadding(tokenizer, max_length=max_length)

            if args.remove_unused_columns:
                args.remove_unused_columns = False
                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
                    args.remove_unused_columns = False
                except FrozenInstanceError:
                    args = replace(args, remove_unused_columns=False)
                # warn users
                warnings.warn(
                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
trl/trainer/sft_trainer.py
@ -85,7 +85,7 @@ class SFTTrainer(Trainer):
            The number of sequences to use for the `ConstantLengthDataset`. Defaults to `1024`.
        chars_per_token (`Optional[float]`):
            The number of characters per token to use for the `ConstantLengthDataset`. Defaults to `3.6`. You can check how this is computed in the
            stack-llama example: https://github.com/lvwerra/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53.
            stack-llama example: https://github.com/huggingface/trl/blob/08f550674c553c36c51d1027613c29f14f3676a5/examples/stack_llama/scripts/supervised_finetuning.py#L53.
        packing (`Optional[bool]`):
            Used only in case `dataset_text_field` is passed. This argument is used by the `ConstantLengthDataset` to pack the sequences
            of the dataset.
trl/trainer/utils.py
@ -14,8 +14,9 @@
import os
import random
import warnings
from collections import deque
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
@ -61,8 +62,9 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    Args:
        instruction_template (`Optional[str]`): the template form that indicates the start of the human instruction, typically something like
            '### Human:\n'. Useful for assistant-style conversation datasets
        response_template (`str`): the template form that indicates the start of the response, typically something like
            '### Response:\n'
        response_template (`Union[str, List[int]]`): the template form that indicates the start of the response, typically something like
            '### Response:\n'. It can also be passed as tokenized ids, which can be useful when using a tokenizer that encodes the response
            differently if it does not have proper context.
        mlm (`bool`, *optional*, defaults to `False`): Whether or not to use masked language modeling in the underlying
            `DataCollatorForLanguageModeling` class. Note that this option currently has no effect but is present
            for flexibility and backwards-compatibility.
@ -72,7 +74,7 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):

    def __init__(
        self,
        response_template: str,
        response_template: Union[str, List[int]],
        instruction_template: Optional[str] = None,
        *args,
        mlm: bool = False,
@ -83,7 +85,11 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
        self.instruction_template = instruction_template
        self.response_template = response_template
        self.ignore_index = ignore_index
        self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)
        if type(response_template) == list:
            # The user already provides the token ids
            self.response_token_ids = response_template
        else:
            self.response_token_ids = self.tokenizer.encode(self.response_template, add_special_tokens=False)

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        batch = super().torch_call(examples)
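Accepting pre-tokenized ids matters because some tokenizers encode the template string differently when it appears mid-sequence, so plain string matching can fail. A usage sketch follows; the checkpoint name and the `[2:]` slice are illustrative assumptions that depend on the tokenizer.

```python
# Sketch: passing the response template as token ids rather than a string, so
# context-dependent prefix tokens can be sliced off before matching.
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint
response_template = "\n### Response:\n"
# Drop the first ids, which encode the leading newline differently in context.
response_template_ids = tokenizer.encode(response_template, add_special_tokens=False)[2:]

collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer)
```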
@ -101,14 +107,18 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
                        response_token_ids_start_idx = idx

                if response_token_ids_start_idx is None:
                    raise RuntimeError(
                        f'Could not find response key {self.response_token_ids} in token IDs {batch["labels"][i]}'
                    warnings.warn(
                        f"Could not find response key `{self.response_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["labels"][i, :] = self.ignore_index
                else:
                    response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)

                response_token_ids_end_idx = response_token_ids_start_idx + len(self.response_token_ids)

                # Make pytorch loss function ignore all tokens up through the end of the response key
                batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index
                    # Make pytorch loss function ignore all tokens up through the end of the response key
                    batch["labels"][i, :response_token_ids_end_idx] = self.ignore_index

        else:
            for i in range(len(examples)):
@ -123,10 +133,14 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
                    ):
                        response_token_ids_idxs.append(assistant_idx + len(self.response_token_ids))

                if len(self.response_token_ids) == 0:
                    raise RuntimeError(
                        f'Could not find response key {self.response_token_ids} in token IDs {batch["labels"][i]}'
                if len(response_token_ids_idxs) == 0:
                    warnings.warn(
                        f"Could not find response key `{self.response_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["labels"][i, :] = self.ignore_index

                human_token_ids = self.tokenizer.encode(self.instruction_template, add_special_tokens=False)
                for human_idx in np.where(batch["labels"][i] == human_token_ids[0])[0]:
@ -135,9 +149,13 @@ class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
                        human_token_ids_idxs.append(human_idx)

                if len(human_token_ids_idxs) == 0:
                    raise RuntimeError(
                        f'Could not find response key {human_token_ids} in token IDs {batch["labels"][i]}'
                    warnings.warn(
                        f"Could not find instruction key `{self.instruction_template}` in the "
                        f'following instance: {self.tokenizer.decode(batch["input_ids"][i])} '
                        f"This instance will be ignored in loss calculation. "
                        f"Note, if this happens often, consider increasing the `max_seq_length`."
                    )
                    batch["labels"][i, :] = self.ignore_index

                for idx, (start, end) in enumerate(zip(human_token_ids_idxs, response_token_ids_idxs)):
                    # Make pytorch loss function ignore all non response tokens
@ -243,7 +261,7 @@ class DPODataCollatorWithPadding:
        padding_value (`int`, defaults to 0):
            The value used for padding.
        truncation_mode: (`str`, defaults to "keep_end"):
            The truncation mode to use when truncating the prompt + chosen/rejected responses.
            The truncation mode to use when truncating the prompt.
    """
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str] = True
@ -499,6 +517,65 @@ class PeftSavingCallback(TrainerCallback):
                os.remove(os.path.join(checkpoint_path, "pytorch_model.bin"))


class RunningMoments:
    def __init__(self, accelerator):
        """
        Calculates the running mean and standard deviation of a data stream. Reference:
        https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L75
        """
        self.mean = 0
        self.std = 1
        self.var = 1
        self.count = 1e-24
        self.accelerator = accelerator

    @torch.no_grad()
    def update(self, xs: torch.Tensor) -> Tuple[float, float]:
        """
        Updates running moments from batch's moments computed across ranks
        """
        if self.accelerator.use_distributed:
            xs_mean, xs_var, xs_count = get_global_statistics(self.accelerator, xs)
        else:
            xs_count = xs.numel()
            xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
        xs_mean, xs_var = xs_mean.float(), xs_var.float()

        delta = xs_mean - self.mean
        tot_count = self.count + xs_count

        new_sum = xs_var * xs_count
        # correct old_sum deviation accounting for the new mean
        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
        tot_sum = old_sum + new_sum

        self.mean += delta * xs_count / tot_count
        self.var = tot_sum / tot_count
        self.std = (self.var * tot_count / (tot_count - 1)).float().sqrt()
        self.count = tot_count

        return xs_mean.item(), (xs_var * xs_count / (xs_count - 1)).float().sqrt().item()


@torch.no_grad()
def get_global_statistics(accelerator, xs: torch.Tensor, mask=None, device="cpu") -> Tuple[float, float, int]:
    """
    Computes element-wise mean and variance of the tensor across processes. Reference:
    https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/utils.py#L57C1-L73C75
    """
    xs = xs.to(accelerator.device)
    sum_and_count = torch.tensor([xs.sum(), (xs.numel() if mask is None else mask.sum())], device=xs.device)
    sum_and_count = accelerator.reduce(sum_and_count)
    global_sum, count = sum_and_count
    global_mean = global_sum / count

    sum_var = torch.sum(((xs - global_mean) ** 2).mul(1 if mask is None else mask))
    sum_var = accelerator.reduce(sum_var)
    global_var = sum_var / count

    return global_mean.to(device), global_var.to(device), count.to(device)


def compute_accuracy(eval_pred) -> Dict[str, float]:
    predictions, labels = eval_pred
    # Here, predictions is rewards_chosen and rewards_rejected.
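`RunningMoments.update` combines batch statistics with the running ones using the standard parallel-variance formula (Chan et al.). Below is a single-process check of that update rule against statistics computed on all of the data at once; `SimpleRunningMoments` is a hypothetical, simplified variant that drops the accelerator/distributed handling.

```python
# Single-process check of the streaming mean/variance combination used above
# (parallel variance, Chan et al.). Accelerator/distributed handling is dropped.
import torch

class SimpleRunningMoments:
    def __init__(self):
        self.mean, self.var, self.count = 0.0, 1.0, 1e-24

    @torch.no_grad()
    def update(self, xs: torch.Tensor):
        xs_count = xs.numel()
        xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
        delta = xs_mean - self.mean
        tot_count = self.count + xs_count
        new_sum = xs_var * xs_count
        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
        self.mean += (delta * xs_count / tot_count).item()
        self.var = ((old_sum + new_sum) / tot_count).item()
        self.count = tot_count

rm = SimpleRunningMoments()
chunks = [torch.randn(16), torch.randn(32) + 1.0, torch.randn(8) * 2.0]
for chunk in chunks:
    rm.update(chunk)

all_data = torch.cat(chunks)
print(round(rm.mean, 4), round(all_data.mean().item(), 4))               # ~equal
print(round(rm.var, 4), round(all_data.var(unbiased=False).item(), 4))   # ~equal
```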
@ -522,3 +599,58 @@ def pad_to_length(tensor: torch.Tensor, length: int, pad_value: Union[int, float
        ],
        dim=dim,
    )


def disable_dropout_in_model(model: torch.nn.Module) -> None:
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0


def exact_div(a, b, a_str, b_str, custom_error_message=""):
    q = a // b
    if a != q * b:
        raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}")
    return q
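A small usage sketch for the two helpers above; the batch-size values are made up, and `exact_div` is re-stated here only so the sketch runs standalone.

```python
# Sketch: `exact_div` guards integer splits (e.g. batch into mini-batches),
# `disable_dropout_in_model` zeroes every Dropout probability in a model.
import torch

def exact_div(a, b, a_str, b_str, custom_error_message=""):
    q = a // b
    if a != q * b:
        raise ValueError(f"{custom_error_message}, {a_str}={a}, {b_str}={b}, inexact division: {a} / {b} = {a / b}")
    return q

num_mini_batches = exact_div(256, 32, "batch_size", "mini_batch_size")  # -> 8

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.1))
for module in model.modules():
    if isinstance(module, torch.nn.Dropout):
        module.p = 0  # same effect as disable_dropout_in_model(model)
print(num_mini_batches, model[1].p)
```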
# copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py#L5
class PerPromptStatTracker:
    r"""
    Class for tracking statistics per prompt. Mainly used to calculate advantage for the DPPO algorithm

    Args:
        buffer_size (`int`):
            Size of the buffer to keep for each prompt.
        min_count (`int`):
            Minimum number of samples to keep in the buffer before calculating the mean and std.
    """

    def __init__(self, buffer_size, min_count):
        self.buffer_size = buffer_size
        self.min_count = min_count
        self.stats = {}

    def update(self, prompts, rewards):
        prompts = np.array(prompts)
        rewards = np.array(rewards)
        unique = np.unique(prompts)
        advantages = np.empty_like(rewards)
        for prompt in unique:
            prompt_rewards = rewards[prompts == prompt]
            if prompt not in self.stats:
                self.stats[prompt] = deque(maxlen=self.buffer_size)
            self.stats[prompt].extend(prompt_rewards)

            if len(self.stats[prompt]) < self.min_count:
                mean = np.mean(rewards)
                std = np.std(rewards) + 1e-6
            else:
                mean = np.mean(self.stats[prompt])
                std = np.std(self.stats[prompt]) + 1e-6
            advantages[prompts == prompt] = (prompt_rewards - mean) / std

        return advantages

    def get_stats(self):
        return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()}
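A toy usage sketch for the tracker added above: rewards are normalised per prompt so each prompt's generations compete against their own history. The prompts and rewards are made up, and the import path assumes the trl version from this diff is installed.

```python
# Sketch: per-prompt advantage computation with PerPromptStatTracker.
# Prompts/rewards are toy values; import path follows the file shown above.
import numpy as np
from trl.trainer.utils import PerPromptStatTracker

tracker = PerPromptStatTracker(buffer_size=16, min_count=4)

prompts = ["a cat", "a cat", "a dog", "a dog", "a dog"]
rewards = np.array([0.2, 0.8, 0.1, 0.5, 0.9])

advantages = tracker.update(prompts, rewards)  # (reward - per-prompt mean) / std
print(advantages)
print(tracker.get_stats())
```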