mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
🕊️ Migration PPOv2
-> PPO
(#2174)
* delete old ppo * rename ppov2 files * PPOv2 -> PPO * rm old doc * rename ppo doc file * rm old test * rename test * re-add v2 with deprecation * style * start update customization * Lion * Finish update customization * remove ppo_multi_adaptater * remove ppo example * update some doc * rm test no peft * rm hello world * processing class * Update docs/source/detoxifying_a_lm.mdx Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com> * Update trl/trainer/ppov2_config.py Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com> * Update docs/source/customization.mdx Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * Update docs/source/detoxifying_a_lm.mdx Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * po to example overview * drop lion * remove "Use 8-bit optimizer" * Update docs/source/customization.mdx * Update docs/source/customization.mdx Co-authored-by: lewtun <lewis.c.tunstall@gmail.com> * it applies to all trainers --------- Co-authored-by: Edward Beeching <edbeeching@users.noreply.github.com> Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
This commit is contained in:
committed by
GitHub
parent
d0aa421e5e
commit
70036bf87f
@ -42,8 +42,6 @@
|
||||
title: ORPO
|
||||
- local: ppo_trainer
|
||||
title: PPO
|
||||
- local: ppov2_trainer
|
||||
title: PPOv2
|
||||
- local: reward_trainer
|
||||
title: Reward
|
||||
- local: rloo_trainer
|
||||
|
@ -1,6 +1,6 @@
|
||||
# Training customization
|
||||
|
||||
TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques.
|
||||
TRL is designed with modularity in mind so that users to be able to efficiently customize the training loop for their needs. Below are some examples on how you can apply and test different techniques. Note: Although these examples use the DPOTrainer, the customization applies to most (if not all) trainers.
|
||||
|
||||
## Train on multiple GPUs / nodes
|
||||
|
||||
@ -46,171 +46,118 @@ else:
|
||||
Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
|
||||
|
||||
|
||||
## Use different optimizers
|
||||
## Use different optimizers and schedulers
|
||||
|
||||
By default, the `DPOTrainer` creates a `torch.optim.AdamW` optimizer. You can create and define a different optimizer and pass it to `DPOTrainer` as follows:
|
||||
|
||||
By default, the `PPOTrainer` creates a `torch.optim.Adam` optimizer. You can create and define a different optimizer and pass it to `PPOTrainer`:
|
||||
```python
|
||||
import torch
|
||||
from transformers import GPT2Tokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from torch import optim
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
|
||||
# 2. define config
|
||||
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
|
||||
config = PPOConfig(**ppo_config)
|
||||
optimizer = optim.SGD(model.parameters(), lr=training_args.learning_rate)
|
||||
|
||||
|
||||
# 2. Create optimizer
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
|
||||
|
||||
|
||||
# 3. initialize trainer
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
optimizers=(optimizer, None),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
For memory efficient fine-tuning, you can also pass `Adam8bit` optimizer from `bitsandbytes`:
|
||||
### Add a learning rate scheduler
|
||||
|
||||
You can also play with your training by adding learning rate schedulers.
|
||||
|
||||
```python
|
||||
import torch
|
||||
import bitsandbytes as bnb
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from torch import optim
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
from transformers import GPT2Tokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
optimizer = optim.AdamW(model.parameters(), lr=training_args.learning_rate)
|
||||
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
|
||||
|
||||
# 2. define config
|
||||
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
|
||||
config = PPOConfig(**ppo_config)
|
||||
|
||||
|
||||
# 2. Create optimizer
|
||||
optimizer = bnb.optim.Adam8bit(model.parameters(), lr=config.learning_rate)
|
||||
|
||||
# 3. initialize trainer
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
|
||||
```
|
||||
|
||||
### Use LION optimizer
|
||||
|
||||
You can use the new [LION optimizer from Google](https://huggingface.co/papers/2302.06675) as well, first take the source code of the optimizer definition [here](https://github.com/lucidrains/lion-pytorch/blob/main/lion_pytorch/lion_pytorch.py), and copy it so that you can import the optimizer. Make sure to initialize the optimizer by considering the trainable parameters only for a more memory efficient training:
|
||||
```python
|
||||
optimizer = Lion(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.config.learning_rate)
|
||||
|
||||
...
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer)
|
||||
```
|
||||
We advise you to use the learning rate that you would use for `Adam` divided by 3 as pointed out [here](https://github.com/lucidrains/lion-pytorch#lion---pytorch). We observed an improvement when using this optimizer compared to classic Adam (check the full logs [here](https://wandb.ai/distill-bloom/trl/runs/lj4bheke?workspace=user-younesbelkada)):
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-lion.png">
|
||||
</div>
|
||||
|
||||
|
||||
## Add a learning rate scheduler
|
||||
|
||||
You can also play with your training by adding learning rate schedulers!
|
||||
```python
|
||||
import torch
|
||||
from transformers import GPT2Tokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
|
||||
# 2. define config
|
||||
ppo_config = {'batch_size': 1, 'learning_rate':1e-5}
|
||||
config = PPOConfig(**ppo_config)
|
||||
|
||||
|
||||
# 2. Create optimizer
|
||||
optimizer = torch.optim.SGD(model.parameters(), lr=config.learning_rate)
|
||||
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
|
||||
|
||||
# 3. initialize trainer
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, optimizer=optimizer, lr_scheduler=lr_scheduler)
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
optimizers=(optimizer, lr_scheduler),
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Memory efficient fine-tuning by sharing layers
|
||||
|
||||
Another tool you can use for more memory efficient fine-tuning is to share layers between the reference model and the model you want to train.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import create_reference_model, DPOConfig, DPOTrainer
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
ref_model = create_reference_model(model, num_shared_layers=6)
|
||||
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train[:1%]")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {'batch_size': 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Pass 8-bit reference models
|
||||
|
||||
<div>
|
||||
Since `trl` supports all keyword arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
|
||||
|
||||
Since `trl` supports all key word arguments when loading a model from `transformers` using `from_pretrained`, you can also leverage `load_in_8bit` from `transformers` for more memory efficient fine-tuning.
|
||||
|
||||
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/perf_infer_gpu_one#bitsandbytes-integration-for-int8-mixedprecision-matrix-decomposition).
|
||||
|
||||
</div>
|
||||
Read more about 8-bit model loading in `transformers` [here](https://huggingface.co/docs/transformers/en/peft#load-in-8bit-or-4bit).
|
||||
|
||||
```python
|
||||
# 0. imports
|
||||
# pip install bitsandbytes
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOConfig, DPOTrainer
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m')
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained('bigscience/bloom-560m', device_map="auto", load_in_8bit=True)
|
||||
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
|
||||
ref_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", quantization_config= quantization_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
|
||||
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
|
||||
training_args = DPOConfig(output_dir="Qwen2.5-0.5B-DPO")
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {'batch_size': 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Use the CUDA cache optimizer
|
||||
|
||||
When training large models, you should better handle the CUDA cache by iteratively clearing it. Do do so, simply pass `optimize_cuda_cache=True` to `PPOConfig`:
|
||||
When training large models, you should better handle the CUDA cache by iteratively clearing it. To do so, simply pass `optimize_cuda_cache=True` to `DPOConfig`:
|
||||
|
||||
```python
|
||||
config = PPOConfig(..., optimize_cuda_cache=True)
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Use score scaling/normalization/clipping
|
||||
As suggested by [Secrets of RLHF in Large Language Models Part I: PPO](https://huggingface.co/papers/2307.04964), we support score (aka reward) scaling/normalization/clipping to improve training stability via `PPOConfig`:
|
||||
```python
|
||||
from trl import PPOConfig
|
||||
|
||||
ppo_config = {
|
||||
use_score_scaling=True,
|
||||
use_score_norm=True,
|
||||
score_clip=0.5,
|
||||
}
|
||||
config = PPOConfig(**ppo_config)
|
||||
```
|
||||
|
||||
To run `ppo.py`, you can use the following command:
|
||||
```
|
||||
python examples/scripts/ppo.py --log_with wandb --use_score_scaling --use_score_norm --score_clip 0.5
|
||||
training_args = DPOConfig(..., optimize_cuda_cache=True)
|
||||
```
|
||||
|
@ -205,7 +205,7 @@ Choosing the right dataset format depends on the task you are working on and the
|
||||
| [`NashMDTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`OnlineDPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
| [`ORPOTrainer`] | [Preference (explicit prompt recommended)](#preference) |
|
||||
| [`PPOv2Trainer`] | Tokenized language modeling |
|
||||
| [`PPOTrainer`] | Tokenized language modeling |
|
||||
| [`RewardTrainer`] | [Preference (implicit prompt recommended)](#preference) |
|
||||
| [`SFTTrainer`] | [Language modeling](#language-modeling) |
|
||||
| [`XPOTrainer`] | [Prompt-only](#prompt-only) |
|
||||
|
@ -98,19 +98,15 @@ model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=
|
||||
|
||||
and the optimizer will take care of computing the gradients in `bfloat16` precision. Note that this is a pure `bfloat16` training which is different from the mixed precision training. If one wants to train a model in mixed-precision, they should not load the model with `torch_dtype` and specify the mixed precision argument when calling `accelerate config`.
|
||||
|
||||
- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by just speifying `num_shared_layers` argument when creating a `PPOTrainer`:
|
||||
- Use shared layers: Since PPO algorithm requires to have both the active and reference model to be on the same device, we have decided to use shared layers to reduce the memory footprint of the model. This can be achieved by specifying `num_shared_layers` argument when calling the `create_reference_model()` function. For example, if you want to share the first 6 layers of the model, you can do it like this:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl-shared-layers.png">
|
||||
</div>
|
||||
|
||||
```python
|
||||
ppo_trainer = PPOTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
num_shared_layers=4,
|
||||
...
|
||||
)
|
||||
ref_policy = create_reference_model(model, num_shared_layers=6)
|
||||
trainer = PPOTrainer(..., ref_policy=ref_policy)
|
||||
```
|
||||
|
||||
In the example above this means that the model have the 4 first layers frozen (i.e. since these layers are shared between the active model and the reference model).
|
||||
|
@ -12,7 +12,7 @@ The abstract from the paper is the following:
|
||||
|
||||
The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
|
||||
|
||||
Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppov2_trainer):
|
||||
Then, fine-tuning a language model via DPO consists of two steps and is easier than [PPO](ppo_trainer):
|
||||
|
||||
1. **Data collection**: Gather a [preference dataset](dataset_formats#preference) with positive and negative selected pairs of generation, given a prompt.
|
||||
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
|
||||
|
@ -44,8 +44,8 @@ Then, it is encouraged to launch jobs with `accelerate launch`!
|
||||
| [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py) | This script shows how to use the [`DPOTrainer`] to fine-tune a stable to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/kto.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/kto.py) | This script shows how to use the [`KTOTrainer`] to fine-tune a model. |
|
||||
| [`examples/scripts/orpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/orpo.py) | This script shows how to use the [`ORPOTrainer`] to fine-tune a model to increase helpfulness and harmlessness using the [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset. |
|
||||
| [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the [`PPOTrainer`] to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. |
|
||||
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a sentiment analysis model using [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb). |
|
||||
| [`examples/scripts/ppo/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to continue text with positive sentiment or physically descriptive language |
|
||||
| [`examples/scripts/ppo/ppo_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo/ppo_tldr.py) | This script shows how to use the [`PPOTrainer`] to fine-tune a model to improve its ability to generate TL;DR summaries. |
|
||||
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the [`RewardTrainer`] to train a reward model on your own dataset. |
|
||||
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a model or adapters into a target dataset. |
|
||||
| [`examples/scripts/sft_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft_vlm.py) | This script shows how to use the [`SFTTrainer`] to fine-tune a Vision Language Model in a chat setting. The script has only been tested with [LLaVA 1.5](https://huggingface.co/llava-hf/llava-1.5-7b-hf), [LLaVA 1.6](https://huggingface.co/llava-hf/llava-v1.6-mistral-7b-hf), and [Llama-3.2-11B-Vision-Instruct](https://huggingface.co/meta-llama/Llama-3.2-11B-Vision-Instruct) models so users may see unexpected behaviour in other model architectures. |
|
||||
|
@ -1,15 +1,14 @@
|
||||
# Logging
|
||||
|
||||
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
|
||||
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to `wandb` or `tensorboard`.
|
||||
By default, the TRL [`PPOTrainer`] saves a lot of relevant information to wandb or tensorboard.
|
||||
|
||||
Upon initialization, pass one of these two options to the [`PPOConfig`]:
|
||||
|
||||
```
|
||||
config = PPOConfig(
|
||||
model_name=args.model_name,
|
||||
log_with=`wandb`, # or `tensorboard`
|
||||
)
|
||||
training_args = PPOConfig(..., report_to="wandb") # or "tensorboard"
|
||||
```
|
||||
|
||||
If you want to log with tensorboard, add the kwarg `project_kwargs={"logging_dir": PATH_TO_LOGS}` to the PPOConfig.
|
||||
|
||||
## PPO Logging
|
||||
|
@ -1,4 +1,4 @@
|
||||
# PPOv2 Trainer
|
||||
# PPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=ppo,trl)
|
||||
|
||||
@ -167,7 +167,7 @@ In the logs the sampled generations look like
|
||||
|
||||
## Implementation details
|
||||
|
||||
This PPOv2 implementation is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
This PPO implementation is based on the [The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization](https://huggingface.co/papers/2403.17031).
|
||||
|
||||
## Benchmark experiments
|
||||
|
||||
@ -222,14 +222,14 @@ python -m openrlbenchmark.rlops_multi_metrics \
|
||||
--pc.ncols 4 \
|
||||
--pc.ncols-legend 1 \
|
||||
--pc.xlabel "Episode" \
|
||||
--output-filename benchmark/trl/pr-1540/ppov2 \
|
||||
--output-filename benchmark/trl/pr-1540/ppo \
|
||||
--scan-history
|
||||
```
|
||||
|
||||
## PPOv2Trainer
|
||||
## PPOTrainer
|
||||
|
||||
[[autodoc]] PPOv2Trainer
|
||||
[[autodoc]] PPOTrainer
|
||||
|
||||
## PPOv2Config
|
||||
## PPOConfig
|
||||
|
||||
[[autodoc]] PPOv2Config
|
||||
[[autodoc]] PPOConfig
|
@ -1,173 +0,0 @@
|
||||
# PPO Trainer
|
||||
|
||||
[](https://huggingface.co/models?other=ppo,trl)
|
||||
|
||||
TRL supports the [PPO](https://huggingface.co/papers/1707.06347) Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric or from preference data using a Reward Model. For a full example have a look at [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb). The trainer is heavily inspired by the original [OpenAI learning to summarize work](https://github.com/openai/summarize-from-feedback).
|
||||
|
||||
The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm.
|
||||
|
||||
## How PPO works
|
||||
|
||||
Fine-tuning a language model via PPO consists of roughly three steps:
|
||||
|
||||
1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
|
||||
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
|
||||
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
|
||||
|
||||
This process is illustrated in the sketch below:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
|
||||
</div>
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, we then use these prompts to generate the a responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated response. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
|
||||
|
||||
Therefore the dataset should contain a text column which we can rename to `query`. Each of the other data-points required to optimize the SFT model are obtained during the training loop.
|
||||
|
||||
Here is an example with the [HuggingFaceH4/cherry_picked_prompts](https://huggingface.co/datasets/HuggingFaceH4/cherry_picked_prompts) dataset:
|
||||
|
||||
```py
|
||||
from datasets import load_dataset
|
||||
|
||||
dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
|
||||
dataset = dataset.rename_column("prompt", "query")
|
||||
dataset = dataset.remove_columns(["meta", "completion"])
|
||||
```
|
||||
|
||||
Resulting in the following subset of the dataset:
|
||||
|
||||
```py
|
||||
ppo_dataset_dict = {
|
||||
"query": [
|
||||
"Explain the moon landing to a 6 year old in a few sentences.",
|
||||
"Why aren’t birds real?",
|
||||
"What happens if you fire a cannonball directly at a pumpkin at high speeds?",
|
||||
"How can I steal from a grocery store without getting caught?",
|
||||
"Why is it important to eat socks after meditating? "
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## Using the `PPOTrainer`
|
||||
|
||||
For a detailed example have a look at the [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/lvwerra/trl/blob/main/examples/notebooks/gpt2-sentiment.ipynb) notebook. At a high level we need to initialize the `PPOTrainer` with a `model` we wish to train. Additionally, we require a reference `reward_model` which we will use to rate the generated response.
|
||||
|
||||
### Initializing the `PPOTrainer`
|
||||
|
||||
The `PPOConfig` dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer.
|
||||
|
||||
```py
|
||||
from trl import PPOConfig
|
||||
|
||||
config = PPOConfig(
|
||||
model_name="gpt2",
|
||||
learning_rate=1.41e-5,
|
||||
)
|
||||
```
|
||||
|
||||
Now we can initialize our model. Note that PPO also requires a reference model, but this model is generated by the 'PPOTrainer` automatically. The model can be initialized as follows:
|
||||
|
||||
```py
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
|
||||
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
```
|
||||
|
||||
As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using `transformers.pipeline` for ease of use.
|
||||
|
||||
```py
|
||||
from transformers import pipeline
|
||||
|
||||
reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")
|
||||
```
|
||||
|
||||
Lastly, we pretokenize our dataset using the `tokenizer` to ensure we can efficiently generate responses during the training loop:
|
||||
|
||||
```py
|
||||
def tokenize(sample):
|
||||
sample["input_ids"] = tokenizer.encode(sample["query"])
|
||||
return sample
|
||||
|
||||
dataset = dataset.map(tokenize, batched=False)
|
||||
```
|
||||
|
||||
Now we are ready to initialize the `PPOTrainer` using the defined config, datasets, and model.
|
||||
|
||||
```py
|
||||
from trl import PPOTrainer
|
||||
|
||||
ppo_trainer = PPOTrainer(
|
||||
model=model,
|
||||
config=config,
|
||||
dataset=dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
|
||||
### Starting the training loop
|
||||
|
||||
Because the `PPOTrainer` needs an active `reward` per execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example we will be using the sentiment `reward_model` initialized above.
|
||||
|
||||
To guide the generation process we use the `generation_kwargs` which are passed to the `model.generate` method for the SFT-model during each step. A more detailed example can be found over [here](how_to_train#how-to-generate-text-for-training).
|
||||
|
||||
```py
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
}
|
||||
```
|
||||
|
||||
We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the `reward_model` and pass these rewards to the `ppo_trainer.step` method. The `ppo_trainer.step` method will then optimize the SFT model using the PPO algorithm.
|
||||
|
||||
```py
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
epochs = 10
|
||||
for epoch in tqdm(range(epochs), "epoch: "):
|
||||
for batch in tqdm(ppo_trainer.dataloader):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
#### Get response from SFTModel
|
||||
response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
|
||||
batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
|
||||
|
||||
#### Compute reward score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = reward_model(texts)
|
||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
||||
|
||||
#### Run PPO step
|
||||
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
|
||||
#### Save model
|
||||
ppo_trainer.save_pretrained("my_ppo_model")
|
||||
```
|
||||
|
||||
## Logging
|
||||
|
||||
While training and evaluating we log the following metrics:
|
||||
|
||||
- `stats`: The statistics of the PPO algorithm, including the loss, entropy, etc.
|
||||
- `batch`: The batch of data used to train the SFT model.
|
||||
- `rewards`: The rewards obtained from the Reward model.
|
||||
|
||||
## PPOTrainer
|
||||
|
||||
[[autodoc]] PPOTrainer
|
||||
|
||||
## PPOConfig
|
||||
|
||||
[[autodoc]] PPOConfig
|
@ -1,54 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# 0. imports
|
||||
import torch
|
||||
from transformers import GPT2Tokenizer
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
|
||||
|
||||
# 1. load a pretrained model
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
|
||||
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
|
||||
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# 2. initialize trainer
|
||||
ppo_config = {"mini_batch_size": 1, "batch_size": 1}
|
||||
config = PPOConfig(**ppo_config)
|
||||
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)
|
||||
|
||||
# 3. encode a query
|
||||
query_txt = "This morning I went to the "
|
||||
query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)
|
||||
|
||||
# 4. generate model response
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"max_new_tokens": 20,
|
||||
}
|
||||
response_tensor = ppo_trainer.generate(list(query_tensor), return_prompt=False, **generation_kwargs)
|
||||
response_txt = tokenizer.decode(response_tensor[0])
|
||||
|
||||
# 5. define a reward for response
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0, device=model.pretrained_model.device)]
|
||||
|
||||
# 6. train model with ppo
|
||||
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
|
@ -1,200 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
python examples/scripts/ppo.py \
|
||||
--log_with=wandb
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator, PartialState
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, HfArgumentParser, is_torch_npu_available, is_torch_xpu_available, pipeline
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead, PPOConfig, PPOTrainer, set_seed
|
||||
from trl.core import LengthSampler
|
||||
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
use_seq2seq: bool = field(default=False, metadata={"help": "whether to use seq2seq"})
|
||||
trust_remote_code: bool = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
|
||||
|
||||
# LoraConfig
|
||||
use_peft: bool = field(default=False, metadata={"help": "whether to use peft"})
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
lora_r: Optional[int] = field(default=16, metadata={"help": "the lora r parameter"})
|
||||
|
||||
|
||||
parser = HfArgumentParser((ScriptArguments, PPOConfig))
|
||||
script_args, ppo_config = parser.parse_args_into_dataclasses()
|
||||
|
||||
# We then define the arguments to pass to the sentiment analysis pipeline.
|
||||
# We set `return_all_scores` to True to get the sentiment score for each token.
|
||||
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}
|
||||
|
||||
trl_model_class = (
|
||||
AutoModelForCausalLMWithValueHead if not script_args.use_seq2seq else AutoModelForSeq2SeqLMWithValueHead
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
|
||||
# Below is an example function to build the dataset. In our case, we use the IMDB dataset
|
||||
# from the `datasets` library. One should customize this function to train the model on
|
||||
# its own dataset.
|
||||
def build_dataset(query_dataset, dataset_num_proc, input_min_text_length=2, input_max_text_length=8):
|
||||
"""
|
||||
Build dataset for training. This builds the dataset from `load_dataset`, one should
|
||||
customize this function to train the model on its own dataset.
|
||||
|
||||
Args:
|
||||
query_dataset (`str`):
|
||||
The name of the dataset to be loaded.
|
||||
|
||||
Returns:
|
||||
dataloader (`torch.utils.data.DataLoader`):
|
||||
The dataloader for the dataset.
|
||||
"""
|
||||
# load imdb with datasets
|
||||
dataset = load_dataset(query_dataset, split="train")
|
||||
dataset = dataset.rename_columns({"text": "review"})
|
||||
dataset = dataset.filter(lambda x: len(x["review"]) > 200, num_proc=dataset_num_proc)
|
||||
|
||||
input_size = LengthSampler(input_min_text_length, input_max_text_length)
|
||||
|
||||
def tokenize(sample):
|
||||
sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
|
||||
sample["query"] = tokenizer.decode(sample["input_ids"])
|
||||
return sample
|
||||
|
||||
dataset = dataset.map(tokenize, num_proc=dataset_num_proc)
|
||||
dataset.set_format(type="torch")
|
||||
return dataset
|
||||
|
||||
|
||||
# We retrieve the dataloader by calling the `build_dataset` function.
|
||||
# Compute that only on the main process for faster data processing.
|
||||
# see: https://github.com/huggingface/trl/pull/1255
|
||||
with PartialState().local_main_process_first():
|
||||
dataset = build_dataset(ppo_config.query_dataset, ppo_config.dataset_num_proc)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
# set seed before initializing value head for deterministic eval
|
||||
set_seed(ppo_config.seed)
|
||||
|
||||
# Now let's build the model, the reference model, and the tokenizer.
|
||||
if not script_args.use_peft:
|
||||
ref_model = trl_model_class.from_pretrained(ppo_config.model_name, trust_remote_code=script_args.trust_remote_code)
|
||||
device_map = None
|
||||
peft_config = None
|
||||
else:
|
||||
peft_config = LoraConfig(
|
||||
r=script_args.lora_r,
|
||||
lora_alpha=script_args.lora_alpha,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
ref_model = None
|
||||
# Copy the model to each device
|
||||
device_map = {"": Accelerator().local_process_index}
|
||||
|
||||
model = trl_model_class.from_pretrained(
|
||||
ppo_config.model_name,
|
||||
trust_remote_code=script_args.trust_remote_code,
|
||||
device_map=device_map,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(ppo_config.model_name)
|
||||
|
||||
# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here.
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
# We then build the PPOTrainer, passing the model, the reference model, the tokenizer
|
||||
ppo_trainer = PPOTrainer(ppo_config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)
|
||||
|
||||
# We then build the sentiment analysis pipeline, passing the model name and the
|
||||
# sentiment analysis pipeline arguments. Let's also make sure to set the device
|
||||
# to the same device as the PPOTrainer.
|
||||
device = ppo_trainer.accelerator.device
|
||||
if ppo_trainer.accelerator.num_processes == 1:
|
||||
if is_torch_xpu_available():
|
||||
device = "xpu:0"
|
||||
elif is_torch_npu_available():
|
||||
device = "npu:0"
|
||||
else:
|
||||
device = 0 if torch.cuda.is_available() else "cpu" # to avoid a `pipeline` bug
|
||||
ds_plugin = ppo_trainer.accelerator.state.deepspeed_plugin
|
||||
task, model_name = ppo_config.reward_model.split(":")
|
||||
if ds_plugin is not None and ds_plugin.is_zero3_init_enabled():
|
||||
with ds_plugin.zero3_init_context_manager(enable=False):
|
||||
sentiment_pipe = pipeline(task, model=model_name, device=device)
|
||||
else:
|
||||
sentiment_pipe = pipeline(task, model=model_name, device=device)
|
||||
|
||||
# Some tokenizers like GPT-2's don't have a padding token by default, so we set one here.
|
||||
if sentiment_pipe.tokenizer.pad_token_id is None:
|
||||
sentiment_pipe.tokenizer.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
if sentiment_pipe.model.config.pad_token_id is None:
|
||||
sentiment_pipe.model.config.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
# We then define the arguments to pass to the `generate` function. These arguments
|
||||
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
|
||||
# the `generate` function of the trained model.
|
||||
generation_kwargs = {
|
||||
"min_length": -1,
|
||||
"top_k": 0.0,
|
||||
"top_p": 1.0,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.eos_token_id,
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
|
||||
for batch in tqdm(ppo_trainer.dataloader):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
# Get response from gpt2
|
||||
response_tensors, ref_response_tensors = ppo_trainer.generate(
|
||||
query_tensors, return_prompt=False, generate_ref_response=True, **generation_kwargs
|
||||
)
|
||||
batch["response"] = tokenizer.batch_decode(response_tensors)
|
||||
batch["ref_response"] = tokenizer.batch_decode(ref_response_tensors)
|
||||
|
||||
# Compute sentiment score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
|
||||
rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
|
||||
ref_texts = [q + r for q, r in zip(batch["query"], batch["ref_response"])]
|
||||
ref_pipe_outputs = sentiment_pipe(ref_texts, **sent_kwargs)
|
||||
ref_rewards = [torch.tensor(output[1]["score"]) for output in ref_pipe_outputs]
|
||||
batch["ref_rewards"] = ref_rewards
|
||||
|
||||
# Run PPO step
|
||||
stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards, columns_to_log=["query", "response", "ref_response", "ref_rewards"])
|
@ -23,7 +23,7 @@ from transformers import (
|
||||
HfArgumentParser,
|
||||
)
|
||||
|
||||
from trl import ModelConfig, PPOv2Config, PPOv2Trainer
|
||||
from trl import ModelConfig, PPOConfig, PPOTrainer
|
||||
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((PPOv2Config, ModelConfig))
|
||||
parser = HfArgumentParser((PPOConfig, ModelConfig))
|
||||
training_args, model_config = parser.parse_args_into_dataclasses()
|
||||
# remove output_dir if exists
|
||||
shutil.rmtree(training_args.output_dir, ignore_errors=True)
|
||||
@ -118,7 +118,7 @@ if __name__ == "__main__":
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = PPOv2Trainer(
|
||||
trainer = PPOTrainer(
|
||||
config=training_args,
|
||||
processing_class=tokenizer,
|
||||
policy=policy,
|
||||
|
@ -23,7 +23,7 @@ from transformers import (
|
||||
HfArgumentParser,
|
||||
)
|
||||
|
||||
from trl import ModelConfig, PPOv2Config, PPOv2Trainer
|
||||
from trl import ModelConfig, PPOConfig, PPOTrainer
|
||||
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
|
||||
|
||||
|
||||
@ -58,7 +58,7 @@ accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((PPOv2Config, ModelConfig))
|
||||
parser = HfArgumentParser((PPOConfig, ModelConfig))
|
||||
training_args, model_config = parser.parse_args_into_dataclasses()
|
||||
# remove output_dir if exists
|
||||
shutil.rmtree(training_args.output_dir, ignore_errors=True)
|
||||
@ -123,7 +123,7 @@ if __name__ == "__main__":
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = PPOv2Trainer(
|
||||
trainer = PPOTrainer(
|
||||
config=training_args,
|
||||
processing_class=tokenizer,
|
||||
policy=policy,
|
||||
|
@ -1,163 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from accelerate import PartialState
|
||||
from datasets import load_dataset
|
||||
from peft import LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
HfArgumentParser,
|
||||
is_torch_npu_available,
|
||||
is_torch_xpu_available,
|
||||
)
|
||||
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
from trl.core import LengthSampler
|
||||
|
||||
|
||||
input_min_text_length = 6
|
||||
input_max_text_length = 12
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The name of the Casual LM model we wish to fine with PPO
|
||||
"""
|
||||
|
||||
model_name: Optional[str] = field(default="huggyllama/llama-7b", metadata={"help": "the model name"})
|
||||
dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
|
||||
rm_adapter: Optional[str] = field(
|
||||
default="trl-lib/llama-7b-hh-rm-adapter", metadata={"help": "the rm adapter name"}
|
||||
)
|
||||
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
|
||||
use_safetensors: Optional[bool] = field(default=False, metadata={"help": "Use safetensors"})
|
||||
seed: Optional[int] = field(default=0, metadata={"help": "the random seed"})
|
||||
use_score_scaling: Optional[bool] = field(default=False, metadata={"help": "Use score scaling"})
|
||||
use_score_norm: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Use score normalization. Only applicable if use_score_scaling is True"}
|
||||
)
|
||||
score_clip: Optional[float] = field(default=None, metadata={"help": "Score clipping"})
|
||||
dataset_num_proc: Optional[int] = field(
|
||||
default=None, metadata={"help": "The number of workers to use to tokenize the data"}
|
||||
)
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
|
||||
def create_and_prepare_dataset(tokenizer, num_proc):
|
||||
dataset = load_dataset(script_args.dataset_name, split="train[:1%]")
|
||||
|
||||
input_size = LengthSampler(input_min_text_length, input_max_text_length)
|
||||
|
||||
def tokenize(example):
|
||||
text_size = input_size()
|
||||
example["input_ids"] = tokenizer.encode(example["chosen"])[:text_size]
|
||||
example["query"] = tokenizer.decode(example["input_ids"])
|
||||
return example
|
||||
|
||||
dataset = dataset.map(tokenize, batched=False, num_proc=num_proc)
|
||||
dataset.set_format("torch")
|
||||
return dataset
|
||||
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
nf4_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16
|
||||
)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
script_args.model_name,
|
||||
device_map={"": "xpu:0"} if is_torch_xpu_available() else {"": "npu:0"} if is_torch_npu_available else {"": 0},
|
||||
peft_config=lora_config,
|
||||
quantization_config=nf4_config,
|
||||
reward_adapter=script_args.rm_adapter,
|
||||
use_safetensors=script_args.use_safetensors,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# Compute that only on the main process for faster data processing.
|
||||
# see: https://github.com/huggingface/trl/pull/1255
|
||||
with PartialState().local_main_process_first():
|
||||
dataset = create_and_prepare_dataset(tokenizer, script_args.dataset_num_proc)
|
||||
|
||||
|
||||
def collator(data):
|
||||
return {key: [d[key] for d in data] for key in data[0]}
|
||||
|
||||
|
||||
config = PPOConfig(
|
||||
model_name=script_args.model_name,
|
||||
log_with=script_args.log_with,
|
||||
learning_rate=1e-5,
|
||||
batch_size=8,
|
||||
mini_batch_size=2,
|
||||
gradient_accumulation_steps=2,
|
||||
optimize_cuda_cache=True,
|
||||
seed=script_args.seed,
|
||||
use_score_scaling=script_args.use_score_scaling,
|
||||
use_score_norm=script_args.use_score_norm,
|
||||
score_clip=script_args.score_clip,
|
||||
)
|
||||
|
||||
ppo_trainer = PPOTrainer(
|
||||
config,
|
||||
model,
|
||||
ref_model=None,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dataset,
|
||||
data_collator=collator,
|
||||
)
|
||||
|
||||
generation_kwargs = {
|
||||
"top_k": 0.0,
|
||||
"top_p": 0.9,
|
||||
"do_sample": True,
|
||||
"pad_token_id": tokenizer.pad_token_id,
|
||||
"max_new_tokens": 32,
|
||||
}
|
||||
|
||||
for _epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
question_tensors = batch["input_ids"]
|
||||
|
||||
response_tensors = ppo_trainer.generate(
|
||||
question_tensors,
|
||||
return_prompt=False,
|
||||
**generation_kwargs,
|
||||
)
|
||||
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
|
||||
|
||||
# Compute reward score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt").to(ppo_trainer.accelerator.device)
|
||||
raw_rewards = ppo_trainer.accelerator.unwrap_model(ppo_trainer.model).compute_reward_score(**inputs)
|
||||
rewards = [raw_rewards[i, -1, 1] for i in range(len(raw_rewards))] # take last token
|
||||
|
||||
# Run PPO step
|
||||
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
@ -1,23 +0,0 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import subprocess
|
||||
|
||||
|
||||
def test_hello_world():
|
||||
subprocess.run(
|
||||
"python examples/hello_world.py",
|
||||
shell=True,
|
||||
check=True,
|
||||
)
|
@ -1,149 +0,0 @@
|
||||
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
from functools import partial
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.utils import import_utils
|
||||
|
||||
|
||||
class DummyDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, query_data, response_data):
|
||||
self.query_data = query_data
|
||||
self.response_data = response_data
|
||||
|
||||
def __len__(self):
|
||||
return len(self.query_data)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.query_data[idx], self.response_data[idx]
|
||||
|
||||
|
||||
EXPECTED_STATS = [
|
||||
"objective/kl",
|
||||
"objective/kl_dist",
|
||||
"objective/logprobs",
|
||||
"objective/ref_logprobs",
|
||||
"objective/kl_coef",
|
||||
"objective/entropy",
|
||||
"ppo/mean_non_score_reward",
|
||||
"ppo/loss/policy",
|
||||
"ppo/loss/value",
|
||||
"ppo/loss/total",
|
||||
"ppo/policy/entropy",
|
||||
"ppo/policy/approxkl",
|
||||
"ppo/policy/policykl",
|
||||
"ppo/policy/clipfrac",
|
||||
"ppo/policy/advantages",
|
||||
"ppo/policy/advantages_mean",
|
||||
"ppo/policy/ratio",
|
||||
"ppo/returns/mean",
|
||||
"ppo/returns/var",
|
||||
"ppo/val/vpred",
|
||||
"ppo/val/error",
|
||||
"ppo/val/clipfrac",
|
||||
"ppo/val/mean",
|
||||
"ppo/val/var",
|
||||
"ppo/val/var_explained",
|
||||
"time/ppo/forward_pass",
|
||||
"time/ppo/compute_rewards",
|
||||
"time/ppo/optimize_step",
|
||||
"time/ppo/calc_stats",
|
||||
"time/ppo/total",
|
||||
"ppo/learning_rate",
|
||||
]
|
||||
|
||||
|
||||
class TestPeftDependancy(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.causal_lm_model_id = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM"
|
||||
self.seq_to_seq_model_id = "trl-internal-testing/tiny-random-T5ForConditionalGeneration"
|
||||
|
||||
def test_no_peft(self):
|
||||
_peft_available = import_utils._peft_available
|
||||
import_utils._peft_available = False # required so that is_peft_available() returns False
|
||||
with patch.dict(sys.modules, {"peft": None}):
|
||||
from trl import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
|
||||
|
||||
# Check that loading a model with `peft` will raise an error
|
||||
with pytest.raises(ModuleNotFoundError):
|
||||
import peft # noqa: F401
|
||||
|
||||
_trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id)
|
||||
_trl_seq2seq_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(self.seq_to_seq_model_id)
|
||||
import_utils._peft_available = _peft_available
|
||||
|
||||
def test_imports_no_peft(self):
|
||||
_peft_available = import_utils._peft_available
|
||||
import_utils._peft_available = False # required so that is_peft_available() returns False
|
||||
with patch.dict(sys.modules, {"peft": None}):
|
||||
from trl import ( # noqa: F401
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
PPOConfig,
|
||||
PPOTrainer,
|
||||
PreTrainedModelWrapper,
|
||||
)
|
||||
import_utils._peft_available = _peft_available
|
||||
|
||||
def test_ppo_trainer_no_peft(self):
|
||||
_peft_available = import_utils._peft_available
|
||||
import_utils._peft_available = False # required so that is_peft_available() returns False
|
||||
with patch.dict(sys.modules, {"peft": None}):
|
||||
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
|
||||
|
||||
ppo_model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
|
||||
|
||||
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(ppo_model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ppo_model_id)
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
ppo_config = PPOConfig(batch_size=2, mini_batch_size=1, log_with=None)
|
||||
|
||||
dummy_dataset = DummyDataset(
|
||||
[torch.LongTensor([0, 1, 0, 1, 0, 1]), torch.LongTensor([0, 1, 0, 1, 0, 1])],
|
||||
[torch.LongTensor([1, 0, 1, 0, 1, 0]), torch.LongTensor([0, 1, 0, 1, 0, 1])],
|
||||
)
|
||||
|
||||
ppo_trainer = PPOTrainer(
|
||||
config=ppo_config,
|
||||
model=trl_model,
|
||||
ref_model=None,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
# define a reward for response
|
||||
# (this could be any reward such as human feedback or output from another model)
|
||||
reward = [torch.tensor(1.0), torch.tensor(0.0)]
|
||||
# train model
|
||||
train_stats = ppo_trainer.step(list(query_tensor), list(response_tensor), reward)
|
||||
break
|
||||
|
||||
# check gradients are not None
|
||||
for _, param in trl_model.named_parameters():
|
||||
if param.requires_grad:
|
||||
assert param.grad is not None
|
||||
|
||||
# check expected stats
|
||||
for stat in EXPECTED_STATS:
|
||||
assert stat in train_stats
|
||||
import_utils._peft_available = _peft_available
|
File diff suppressed because it is too large
Load Diff
@ -1,63 +0,0 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import platform
|
||||
import subprocess
|
||||
|
||||
|
||||
def test():
|
||||
command = """\
|
||||
python examples/scripts/ppo/ppo.py \
|
||||
--learning_rate 3e-6 \
|
||||
--output_dir models/minimal/ppo \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--total_episodes 10 \
|
||||
--model_name_or_path EleutherAI/pythia-14m \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--save_strategy no \
|
||||
--stop_token eos
|
||||
"""
|
||||
if platform.system() == "Windows":
|
||||
# windows CI does not work with subprocesses for some reason
|
||||
# e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
|
||||
return
|
||||
subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
def test_num_train_epochs():
|
||||
command = """\
|
||||
python examples/scripts/ppo/ppo.py \
|
||||
--learning_rate 3e-6 \
|
||||
--output_dir models/minimal/ppo \
|
||||
--per_device_train_batch_size 4 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--num_train_epochs 0.003 \
|
||||
--model_name_or_path EleutherAI/pythia-14m \
|
||||
--missing_eos_penalty 1.0 \
|
||||
--save_strategy no \
|
||||
--stop_token eos
|
||||
"""
|
||||
if platform.system() == "Windows":
|
||||
# windows CI does not work with subprocesses for some reason
|
||||
# e.g., https://github.com/huggingface/trl/actions/runs/9600036224/job/26475286210?pr=1743
|
||||
return
|
||||
subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
check=True,
|
||||
)
|
@ -1,4 +1,4 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -11,28 +11,15 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import tyro
|
||||
from transformers import is_wandb_available
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from trl.trainer.utils import exact_div
|
||||
|
||||
from ..core import flatten_dict
|
||||
|
||||
|
||||
JSONDict = Annotated[Optional[dict], tyro.conf.arg(metavar="JSON", constructor=json.loads)]
|
||||
from ..trainer.utils import OnPolicyConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class PPOConfig:
|
||||
class PPOConfig(OnPolicyConfig):
|
||||
r"""
|
||||
Configuration class for the [`PPOTrainer`].
|
||||
|
||||
@ -41,199 +28,35 @@ class PPOConfig:
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
|
||||
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
|
||||
Name of this experiment.
|
||||
seed (`int`, *optional*, defaults to `0`):
|
||||
Random seed.
|
||||
log_with (`Optional[Literal["wandb", "tensorboard"]]`, *optional*, defaults to `None`):
|
||||
Log with either `"wandb"` or `"tensorboard"`. Check
|
||||
[tracking](https://huggingface.co/docs/accelerate/usage_guides/tracking) for more details.
|
||||
task_name (`Optional[str]`, *optional*, defaults to `None`):
|
||||
Name of task to use - used only for tracking purposes.
|
||||
model_name (`Optional[str]`, *optional*, defaults to `"gpt2"`):
|
||||
Name of model to use - used only for tracking purposes.
|
||||
query_dataset (`Optional[str]`, *optional*, defaults to `"stanfordnlp/imdb"`):
|
||||
Name of dataset to query - used only for tracking purposes.
|
||||
reward_model (`Optional[str]`, *optional*, defaults to `"sentiment-analysis:lvwerra/distilbert-imdb"`):
|
||||
Reward model to use - used only for tracking purposes.
|
||||
remove_unused_columns (`bool`, *optional*, defaults to `True`):
|
||||
Remove unused columns from the dataset.
|
||||
tracker_kwargs (`JSONDict`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the tracker (e.g. `python ppo.py --tracker_kwargs='{"wandb": {"entity": "my_wandb_entity", "name": "my_exp_name"}}'`.
|
||||
accelerator_kwargs (`JSONDict`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the accelerator.
|
||||
project_kwargs (`JSONDict`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for the accelerator project config (e.g. `logging_dir`).
|
||||
tracker_project_name (`str`, *optional*, defaults to `"trl"`):
|
||||
Name of project to use for tracking.
|
||||
push_to_hub_if_best_kwargs (`JSONDict`, *optional*, defaults to `{}`):
|
||||
Keyword arguments for pushing model to the hub during training (e.g. repo_id).
|
||||
steps (`int`, *optional*, defaults to `20000`):
|
||||
Number of training steps.
|
||||
learning_rate (`float`, *optional*, defaults to `1.41e-5`):
|
||||
Learning rate for the optimizer.
|
||||
adap_kl_ctrl (`bool`, *optional*, defaults to `True`):
|
||||
Use adaptive KL control, otherwise linear.
|
||||
init_kl_coef (`Optional[float]`, *optional*, defaults to `0.2`):
|
||||
Initial KL penalty coefficient (used for adaptive and linear control).
|
||||
kl_penalty (`Literal["kl", "abs", "mse", "full"]`, *optional*, defaults to `"kl"`):
|
||||
kl penalty options. Possible values are:
|
||||
|
||||
- `"kl"`: model_logp - ref_logp
|
||||
- `"abs"`: abs(kl)
|
||||
- `"mse"`: mean squared error mse(kl)
|
||||
- `"full"`: the actual kl for all tokens in the distribution.
|
||||
|
||||
target (`float`, *optional*, defaults to `6.0`):
|
||||
Target KL value for adaptive KL control.
|
||||
horizon (`float`, *optional*, defaults to `10000.0`):
|
||||
Horizon for adaptive KL control.
|
||||
gamma (`float`, *optional*, defaults to `1.0`):
|
||||
Gamma parameter for advantage calculation.
|
||||
lam (`float`, *optional*, defaults to `0.95`):
|
||||
Lambda parameter for advantage calculation.
|
||||
cliprange (`float`, *optional*, defaults to `0.2`):
|
||||
Range for clipping in PPO policy gradient loss.
|
||||
cliprange_value (`float`, *optional*, defaults to `0.2`):
|
||||
Range for clipping values in loss calculation.
|
||||
vf_coef (`float`, *optional*, defaults to `0.1`):
|
||||
Scaling factor for value loss.
|
||||
batch_size (`int`, *optional*, defaults to `128`):
|
||||
Number of samples per optimisation step.
|
||||
forward_batch_size (`Optional[int]`, *optional*, defaults to `None`):
|
||||
DEPRECATED: use `mini_batch_size` instead, which does the same thing.
|
||||
mini_batch_size (`int`, *optional*, defaults to `128`):
|
||||
Number of samples optimized in each mini batch.
|
||||
gradient_accumulation_steps (`int`, *optional*, defaults to `1`):
|
||||
Number of gradient accumulation steps.
|
||||
world_size (`Optional[int]`, *optional*, defaults to `None`):
|
||||
Number of processes to use for distributed training.
|
||||
ppo_epochs (`int`, *optional*, defaults to `4`):
|
||||
Number of optimisation epochs per batch of samples.
|
||||
optimize_device_cache (`bool`, *optional*, defaults to `False`):
|
||||
Optimize device cache for slightly more memory-efficient training.
|
||||
early_stopping (`bool`, *optional*, defaults to `False`):
|
||||
Whether to stop the PPO optimization loop early is the KL too high.
|
||||
target_kl (`float`, *optional*, defaults to `1.0`):
|
||||
Stop early if we exceed this value by over 50%.
|
||||
compare_steps (`int`, *optional*, defaults to `1`):
|
||||
Compare the current step with the previous `compare_steps` steps.
|
||||
ratio_threshold (`float`, *optional*, defaults to `10.0`):
|
||||
Skip mini-batches with high PPO ratios that can cause loss spikes.
|
||||
use_score_scaling (`bool`, *optional*, defaults to `False`):
|
||||
Use score scaling.
|
||||
use_score_norm (`bool`, *optional*, defaults to `False`):
|
||||
Use score normalization. Only applicable if `use_score_scaling` is True.
|
||||
score_clip (`Optional[float]`, *optional*, defaults to `None`):
|
||||
Score clipping.
|
||||
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
|
||||
Path to the reward model.
|
||||
num_ppo_epochs (`int`, *optional*, defaults to `4`):
|
||||
Number of epochs to train.
|
||||
whiten_rewards (`bool`, *optional*, defaults to `False`):
|
||||
Whiten the rewards before computing advantages.
|
||||
is_encoder_decoder (`Optional[bool]`, *optional*, defaults to `None`):
|
||||
When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
|
||||
you need to specify if the model returned by the callable is an encoder-decoder model.
|
||||
is_peft_model (`Optional[bool]`, *optional*, defaults to `None`):
|
||||
Whether the model is a PEFT model.
|
||||
backward_batch_size (`Optional[int]`, *optional*, defaults to `None`):
|
||||
Number of samples optimized in an `optimizer.step()` call.
|
||||
global_backward_batch_size (`Optional[int]`, *optional*, defaults to `None`):
|
||||
Effective `backward_batch_size` across all processes.
|
||||
global_batch_size (`Optional[int]`, *optional*, defaults to `None`):
|
||||
Effective `batch_size` across all processes.
|
||||
dataset_num_proc (`Optional[int]`, *optional*, defaults to `None`):
|
||||
Number of processes to use for processing the dataset.
|
||||
Whether to whiten the rewards.
|
||||
kl_coef (`float`, *optional*, defaults to `0.05`):
|
||||
KL coefficient.
|
||||
cliprange (`float`, *optional*, defaults to `0.2`):
|
||||
Clip range.
|
||||
vf_coef (`float`, *optional*, defaults to `0.1`):
|
||||
Value function coefficient.
|
||||
cliprange_value (`float`, *optional*, defaults to `0.2`):
|
||||
Clip range for the value function.
|
||||
gamma (`float`, *optional*, defaults to `1.0`):
|
||||
Discount factor.
|
||||
lam (`float`, *optional*, defaults to `0.95`):
|
||||
Lambda value for GAE.
|
||||
"""
|
||||
|
||||
exp_name: str = os.path.basename(sys.argv[0])[: -len(".py")]
|
||||
seed: int = 0
|
||||
log_with: Optional[Literal["wandb", "tensorboard"]] = None
|
||||
task_name: Optional[str] = None
|
||||
model_name: str = "gpt2"
|
||||
query_dataset: str = "stanfordnlp/imdb"
|
||||
reward_model: str = "sentiment-analysis:lvwerra/distilbert-imdb"
|
||||
remove_unused_columns: bool = True
|
||||
tracker_kwargs: JSONDict = field(default_factory=dict)
|
||||
accelerator_kwargs: JSONDict = field(default_factory=dict)
|
||||
project_kwargs: JSONDict = field(default_factory=dict)
|
||||
tracker_project_name: str = "trl"
|
||||
push_to_hub_if_best_kwargs: JSONDict = field(default_factory=dict)
|
||||
steps: int = 20000
|
||||
learning_rate: float = 1.41e-5
|
||||
adap_kl_ctrl: bool = True
|
||||
init_kl_coef: float = 0.2
|
||||
kl_penalty: Literal["kl", "abs", "mse", "full"] = "kl"
|
||||
target: float = 6.0
|
||||
horizon: float = 10000.0
|
||||
exp_name: str = os.path.basename(__file__)[: -len(".py")]
|
||||
reward_model_path: str = "EleutherAI/pythia-160m"
|
||||
num_ppo_epochs: int = 4
|
||||
whiten_rewards: bool = False
|
||||
kl_coef: float = 0.05
|
||||
cliprange: float = 0.2
|
||||
vf_coef: float = 0.1
|
||||
cliprange_value: float = 0.2
|
||||
gamma: float = 1.0
|
||||
lam: float = 0.95
|
||||
cliprange: float = 0.2
|
||||
cliprange_value: float = 0.2
|
||||
vf_coef: float = 0.1
|
||||
batch_size: int = 128
|
||||
forward_batch_size: Optional[int] = None
|
||||
mini_batch_size: int = 128
|
||||
gradient_accumulation_steps: int = 1
|
||||
world_size: tyro.conf.Suppress[int] = None
|
||||
ppo_epochs: int = 4
|
||||
max_grad_norm: Optional[float] = None
|
||||
optimize_cuda_cache: Optional[bool] = None
|
||||
optimize_device_cache: bool = False
|
||||
early_stopping: bool = False
|
||||
target_kl: float = 1.0
|
||||
compare_steps: int = 1
|
||||
ratio_threshold: float = 10.0
|
||||
use_score_scaling: bool = False
|
||||
use_score_norm: bool = False
|
||||
score_clip: Optional[float] = None
|
||||
whiten_rewards: bool = False
|
||||
gradient_checkpointing: bool = False
|
||||
is_encoder_decoder: Optional[tyro.conf.Suppress[bool]] = None
|
||||
is_peft_model: Optional[tyro.conf.Suppress[bool]] = None
|
||||
backward_batch_size: tyro.conf.Suppress[int] = None
|
||||
global_backward_batch_size: Optional[tyro.conf.Suppress[int]] = None
|
||||
global_batch_size: tyro.conf.Suppress[int] = None
|
||||
dataset_num_proc: Optional[int] = None
|
||||
|
||||
if optimize_cuda_cache is not None:
|
||||
warnings.warn(
|
||||
"The `optimize_cuda_cache` argument will be deprecated soon, please use `optimize_device_cache` instead."
|
||||
)
|
||||
|
||||
if optimize_device_cache is True:
|
||||
raise ValueError("Both `optimize_device_cache` and `optimize_cuda_cache` were provided")
|
||||
|
||||
optimize_device_cache = optimize_cuda_cache
|
||||
|
||||
def __post_init__(self):
|
||||
warnings.warn(
|
||||
"`PPOConfig` is deprecated and will be removed in the future. Please use `PPOv2Config` with `PPOv2Trainer` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
if self.forward_batch_size is not None:
|
||||
warnings.warn(
|
||||
"Note that using `forward_batch_size` is deprecated, use `mini_batch_size` instead. By setting it you overwrite `mini_batch_size` which affects both the batch size during forward passes and also the mini batch size for PPO optimization."
|
||||
)
|
||||
self.mini_batch_size = self.forward_batch_size
|
||||
|
||||
self.backward_batch_size = self.mini_batch_size * self.gradient_accumulation_steps
|
||||
exact_div(
|
||||
self.batch_size,
|
||||
self.backward_batch_size,
|
||||
"`batch_size` must be a multiple of `mini_batch_size * gradient_accumulation_steps`",
|
||||
)
|
||||
|
||||
# check if wandb is installed
|
||||
if self.log_with == "wandb":
|
||||
# raise error if wandb is not installed
|
||||
if not is_wandb_available():
|
||||
raise ImportError(
|
||||
"Please install wandb to use wandb logging. You can do this by running `pip install wandb`."
|
||||
)
|
||||
|
||||
self.total_ppo_epochs = int(np.ceil(self.steps / self.batch_size))
|
||||
assert self.kl_penalty in ["kl", "abs", "mse", "full"]
|
||||
|
||||
def to_dict(self):
|
||||
output_dict = {}
|
||||
for key, value in self.__dict__.items():
|
||||
output_dict[key] = value
|
||||
return flatten_dict(output_dict)
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -12,51 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
import warnings
|
||||
|
||||
from ..trainer.utils import OnPolicyConfig
|
||||
from .ppo_config import PPOConfig
|
||||
|
||||
|
||||
@dataclass
|
||||
class PPOv2Config(OnPolicyConfig):
|
||||
r"""
|
||||
Configuration class for the [`PPOv2Trainer`].
|
||||
|
||||
Using [`~transformers.HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
|
||||
Name of this experiment.
|
||||
reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
|
||||
Path to the reward model.
|
||||
num_ppo_epochs (`int`, *optional*, defaults to `4`):
|
||||
Number of epochs to train.
|
||||
whiten_rewards (`bool`, *optional*, defaults to `False`):
|
||||
Whether to whiten the rewards.
|
||||
kl_coef (`float`, *optional*, defaults to `0.05`):
|
||||
KL coefficient.
|
||||
cliprange (`float`, *optional*, defaults to `0.2`):
|
||||
Clip range.
|
||||
vf_coef (`float`, *optional*, defaults to `0.1`):
|
||||
Value function coefficient.
|
||||
cliprange_value (`float`, *optional*, defaults to `0.2`):
|
||||
Clip range for the value function.
|
||||
gamma (`float`, *optional*, defaults to `1.0`):
|
||||
Discount factor.
|
||||
lam (`float`, *optional*, defaults to `0.95`):
|
||||
Lambda value for GAE.
|
||||
"""
|
||||
|
||||
exp_name: str = os.path.basename(__file__)[: -len(".py")]
|
||||
reward_model_path: str = "EleutherAI/pythia-160m"
|
||||
num_ppo_epochs: int = 4
|
||||
whiten_rewards: bool = False
|
||||
kl_coef: float = 0.05
|
||||
cliprange: float = 0.2
|
||||
vf_coef: float = 0.1
|
||||
cliprange_value: float = 0.2
|
||||
gamma: float = 1.0
|
||||
lam: float = 0.95
|
||||
# Define an alias for PPOv2Config that raises a warning
|
||||
class PPOv2Config(PPOConfig):
|
||||
def __init__(self, *args, **kwargs):
|
||||
warnings.warn(
|
||||
"`PPOv2Config` is deprecated and has been renamed to `PPOConfig`. Please use `PPOConfig` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
@ -12,702 +12,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import math
|
||||
import os
|
||||
import textwrap
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import broadcast, gather_object
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
DataCollatorWithPadding,
|
||||
FeatureExtractionMixin,
|
||||
GenerationConfig,
|
||||
PreTrainedTokenizerBase,
|
||||
ProcessorMixin,
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainerControl,
|
||||
is_wandb_available,
|
||||
)
|
||||
from transformers.integrations import get_reporting_integration_callbacks
|
||||
from transformers.trainer import DEFAULT_CALLBACKS, DEFAULT_PROGRESS_CALLBACK
|
||||
from transformers.trainer_callback import CallbackHandler, ExportableState, PrinterCallback
|
||||
|
||||
from ..core import masked_mean, masked_whiten
|
||||
from ..models.utils import unwrap_model_for_generation
|
||||
from ..trainer.utils import (
|
||||
OnlineTrainerState,
|
||||
batch_generation,
|
||||
disable_dropout_in_model,
|
||||
exact_div,
|
||||
first_true_indices,
|
||||
forward,
|
||||
get_reward,
|
||||
prepare_deepspeed,
|
||||
print_rich_table,
|
||||
truncate_response,
|
||||
)
|
||||
from .ppov2_config import PPOv2Config
|
||||
from .utils import generate_model_card
|
||||
from .ppo_trainer import PPOTrainer
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
INVALID_LOGPROB = 1.0
|
||||
|
||||
|
||||
# taken from https://github.com/OpenLMLab/MOSS-RLHF/blob/40b91eb2f2b71b16919addede0341d2bef70825d/ppo/ppo_trainer.py#L29
|
||||
# we did this we can do a single `model = accelerator.prepare(model)`
|
||||
class PolicyAndValueWrapper(nn.Module):
|
||||
def __init__(self, policy, value_model) -> None:
|
||||
super().__init__()
|
||||
self.policy = policy
|
||||
self.value_model = value_model
|
||||
self.critic_backbone = getattr(value_model, value_model.base_model_prefix)
|
||||
|
||||
def forward(self, **kwargs):
|
||||
output = self.critic_backbone(
|
||||
**kwargs,
|
||||
# Define an alias for PPOv2Trainer that raises a warning
|
||||
class PPOv2Trainer(PPOTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
warnings.warn(
|
||||
"`PPOv2Trainer` is deprecated and has been renamed to `PPOTrainer`. Please use `PPOTrainer` instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
logits = self.value_model.score(output.hidden_states[-1])
|
||||
return self.policy(**kwargs), logits
|
||||
|
||||
|
||||
class PPOv2Trainer(Trainer):
|
||||
_tag_names = ["trl", "ppo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PPOv2Config,
|
||||
processing_class: Optional[
|
||||
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
|
||||
],
|
||||
policy: nn.Module,
|
||||
ref_policy: nn.Module,
|
||||
reward_model: nn.Module,
|
||||
train_dataset: Dataset,
|
||||
value_model: Optional[nn.Module] = None,
|
||||
data_collator: Optional[DataCollatorWithPadding] = None,
|
||||
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||
# less commonly used
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
) -> None:
|
||||
if ref_policy is policy:
|
||||
raise ValueError(
|
||||
"`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
|
||||
"same as `policy`, you must mass a copy of it, or `None` if you use peft."
|
||||
)
|
||||
|
||||
self.args = config
|
||||
args = config
|
||||
self.processing_class = processing_class
|
||||
self.policy = policy
|
||||
|
||||
self.policy.generation_config.eos_token_id = (
|
||||
None # disable `pad_token_id` and `eos_token_id` because we just want to
|
||||
)
|
||||
self.policy.generation_config.pad_token_id = None # generate tokens without truncation / padding
|
||||
|
||||
self.ref_policy = ref_policy
|
||||
self.reward_model = reward_model
|
||||
self.train_dataset = train_dataset
|
||||
self.train_dataset_len = len(train_dataset)
|
||||
self.value_model = value_model
|
||||
self.data_collator = data_collator
|
||||
self.eval_dataset = eval_dataset
|
||||
self.optimizer, self.lr_scheduler = optimizers
|
||||
|
||||
#########
|
||||
# calculate various batch sizes
|
||||
#########
|
||||
if args.total_episodes is None: # allow the users to define episodes in terms of epochs.
|
||||
args.total_episodes = int(args.num_train_epochs * self.train_dataset_len)
|
||||
accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps)
|
||||
self.accelerator = accelerator
|
||||
args.world_size = accelerator.num_processes
|
||||
args.local_batch_size = (
|
||||
args.per_device_train_batch_size * args.gradient_accumulation_steps * args.num_mini_batches
|
||||
)
|
||||
args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size)
|
||||
args.batch_size = int(args.local_batch_size * args.world_size)
|
||||
args.mini_batch_size = exact_div(
|
||||
args.batch_size, args.num_mini_batches, "`batch_size` must be a multiple of `num_mini_batches`"
|
||||
)
|
||||
args.local_mini_batch_size = exact_div(
|
||||
args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`"
|
||||
)
|
||||
if args.whiten_rewards:
|
||||
assert (
|
||||
args.local_mini_batch_size >= 8
|
||||
), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening"
|
||||
# `per_rank_rollout_batch_size` is our `args.local_batch_size`
|
||||
# `per_rank_minibatch_size` is our `args.local_mini_batch_size`
|
||||
args.num_total_batches = math.ceil(
|
||||
args.total_episodes / args.batch_size
|
||||
) # we may train for more than `total_episodes`
|
||||
time_tensor = torch.tensor(int(time.time()), device=accelerator.device)
|
||||
time_int = broadcast(time_tensor, 0).item() # avoid different timestamps across processes
|
||||
args.run_name = f"{args.exp_name}__{args.seed}__{time_int}"
|
||||
self.local_seed = args.seed + accelerator.process_index * 100003 # Prime
|
||||
if args.num_sample_generations > 0:
|
||||
self.sample_generations_freq = max(1, args.num_total_batches // args.num_sample_generations)
|
||||
self.local_dataloader_batch_size = args.local_batch_size
|
||||
|
||||
#########
|
||||
# setup model, optimizer, and others
|
||||
#########
|
||||
for module in [policy, ref_policy, value_model, reward_model]:
|
||||
disable_dropout_in_model(module)
|
||||
if args.stop_token and args.stop_token == "eos":
|
||||
args.stop_token_id = processing_class.eos_token_id
|
||||
self.model = PolicyAndValueWrapper(policy, value_model)
|
||||
self.model.config = policy.config # needed for pushing to hub
|
||||
self.create_optimizer_and_scheduler(
|
||||
num_training_steps=args.num_total_batches
|
||||
) # note that we are calling `self.lr_scheduler.step()` manually only at the batch level
|
||||
|
||||
#########
|
||||
### trainer specifics
|
||||
#########
|
||||
default_callbacks = DEFAULT_CALLBACKS + get_reporting_integration_callbacks(self.args.report_to)
|
||||
self.callbacks = default_callbacks if callbacks is None else default_callbacks + callbacks
|
||||
self.callback_handler = CallbackHandler(
|
||||
self.callbacks, self.model, self.processing_class, self.optimizer, self.lr_scheduler
|
||||
)
|
||||
self.add_callback(PrinterCallback if self.args.disable_tqdm else DEFAULT_PROGRESS_CALLBACK)
|
||||
self.control = TrainerControl()
|
||||
self.state = OnlineTrainerState(
|
||||
is_local_process_zero=self.is_local_process_zero(),
|
||||
is_world_process_zero=self.is_world_process_zero(),
|
||||
stateful_callbacks=[
|
||||
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
|
||||
],
|
||||
)
|
||||
self.current_flos = 0
|
||||
self.hp_search_backend = None
|
||||
self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
|
||||
self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
|
||||
# Create distant repo and output directory if needed
|
||||
self.hub_model_id = None
|
||||
if self.args.push_to_hub:
|
||||
self.init_hf_repo()
|
||||
if self.args.should_save:
|
||||
os.makedirs(self.args.output_dir, exist_ok=True)
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
if hasattr(self.model, "add_model_tags"):
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
#########
|
||||
### setup dataloader
|
||||
#########
|
||||
self.dataloader = DataLoader(
|
||||
self.train_dataset,
|
||||
batch_size=self.local_dataloader_batch_size,
|
||||
shuffle=True,
|
||||
collate_fn=DataCollatorWithPadding(self.processing_class),
|
||||
drop_last=True, # needed; otherwise the last batch will be of ragged shape
|
||||
)
|
||||
# sync random states for DataLoader(shuffle=True) before `accelerator.prepare`
|
||||
# see https://gist.github.com/vwxyzjn/2581bff1e48e185e0b85b6dfe1def79c
|
||||
torch.manual_seed(args.seed)
|
||||
self.model, self.optimizer, self.dataloader = accelerator.prepare(self.model, self.optimizer, self.dataloader)
|
||||
torch.manual_seed(self.local_seed) # reset the local seed again
|
||||
|
||||
self.eval_dataloader = DataLoader(
|
||||
self.eval_dataset,
|
||||
batch_size=args.per_device_eval_batch_size,
|
||||
collate_fn=DataCollatorWithPadding(self.processing_class),
|
||||
drop_last=True,
|
||||
) # no need to shuffle eval dataset
|
||||
self.eval_dataloader = accelerator.prepare(self.eval_dataloader)
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
self.reward_model = prepare_deepspeed(
|
||||
self.reward_model, args.per_device_train_batch_size, args.fp16, args.bf16
|
||||
)
|
||||
self.ref_policy = prepare_deepspeed(
|
||||
self.ref_policy, args.per_device_train_batch_size, args.fp16, args.bf16
|
||||
)
|
||||
else:
|
||||
self.ref_policy = self.ref_policy.to(self.accelerator.device)
|
||||
self.reward_model = self.reward_model.to(self.accelerator.device)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
return self.dataloader
|
||||
|
||||
def get_eval_dataloader(self) -> DataLoader:
|
||||
return self.eval_dataloader
|
||||
|
||||
def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
|
||||
backup_model = self.model
|
||||
self.model = self.model.policy # save only the policy
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
backup_deepspeed = self.deepspeed
|
||||
self.deepspeed = self.model
|
||||
|
||||
super().save_model(output_dir, _internal_call)
|
||||
|
||||
self.model = backup_model
|
||||
|
||||
if self.is_deepspeed_enabled:
|
||||
self.deepspeed = backup_deepspeed
|
||||
|
||||
def train(self):
|
||||
args = self.args
|
||||
accelerator = self.accelerator
|
||||
optimizer = self.optimizer
|
||||
model = self.model
|
||||
ref_policy = self.ref_policy
|
||||
reward_model = self.reward_model
|
||||
processing_class = self.processing_class
|
||||
dataloader = self.dataloader
|
||||
device = accelerator.device
|
||||
|
||||
def repeat_generator():
|
||||
while True:
|
||||
yield from dataloader
|
||||
|
||||
iter_dataloader = iter(repeat_generator())
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=args.response_length,
|
||||
temperature=(args.temperature + 1e-7),
|
||||
top_k=0.0,
|
||||
top_p=1.0,
|
||||
do_sample=True,
|
||||
)
|
||||
|
||||
accelerator.print("===training policy===")
|
||||
start_time = time.time()
|
||||
stats_shape = (args.num_ppo_epochs, args.num_mini_batches, args.gradient_accumulation_steps)
|
||||
approxkl_stats = torch.zeros(stats_shape, device=device)
|
||||
pg_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
||||
pg_loss_stats = torch.zeros(stats_shape, device=device)
|
||||
vf_loss_stats = torch.zeros(stats_shape, device=device)
|
||||
vf_clipfrac_stats = torch.zeros(stats_shape, device=device)
|
||||
entropy_stats = torch.zeros(stats_shape, device=device)
|
||||
ratio_stats = torch.zeros(stats_shape, device=device)
|
||||
model.train()
|
||||
|
||||
# trainer state initialization
|
||||
self.state.global_step = 0
|
||||
self.state.episode = 0
|
||||
self.state.max_steps = args.num_total_batches * args.num_mini_batches
|
||||
self.state.num_train_epochs = args.total_episodes / self.train_dataset_len
|
||||
# Compute absolute values for logging, eval, and save if given as ratio
|
||||
if args.logging_steps is not None:
|
||||
if args.logging_steps < 1:
|
||||
self.state.logging_steps = math.ceil(self.state.max_steps * args.logging_steps)
|
||||
else:
|
||||
self.state.logging_steps = args.logging_steps
|
||||
if args.eval_steps is not None:
|
||||
if args.eval_steps < 1:
|
||||
self.state.eval_steps = math.ceil(self.state.max_steps * args.eval_steps)
|
||||
else:
|
||||
self.state.eval_steps = args.eval_steps
|
||||
if args.save_steps is not None:
|
||||
if args.save_steps < 1:
|
||||
self.state.save_steps = math.ceil(self.state.max_steps * args.save_steps)
|
||||
else:
|
||||
self.state.save_steps = args.save_steps
|
||||
self.control = self.callback_handler.on_train_begin(args, self.state, self.control)
|
||||
|
||||
# backward compatibility
|
||||
if self.is_deepspeed_enabled:
|
||||
self.deepspeed = self.model
|
||||
self.model_wrapped = self.model
|
||||
|
||||
for update in range(1, args.num_total_batches + 1):
|
||||
self.state.episode += 1 * args.batch_size
|
||||
data = next(iter_dataloader)
|
||||
with torch.no_grad():
|
||||
queries = data["input_ids"].to(device)
|
||||
context_length = queries.shape[1]
|
||||
responses = []
|
||||
postprocessed_responses = []
|
||||
logprobs = []
|
||||
ref_logprobs = []
|
||||
scores = []
|
||||
sequence_lengths = []
|
||||
values = []
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
query_responses, logitss = batch_generation(
|
||||
unwrapped_model.policy,
|
||||
queries,
|
||||
args.local_rollout_forward_batch_size,
|
||||
processing_class.pad_token_id,
|
||||
generation_config,
|
||||
)
|
||||
|
||||
for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size):
|
||||
query = queries[i : i + args.local_rollout_forward_batch_size]
|
||||
query_response = query_responses[i : i + args.local_rollout_forward_batch_size]
|
||||
response = query_response[:, context_length:]
|
||||
logits = logitss[i : i + args.local_rollout_forward_batch_size]
|
||||
all_logprob = F.log_softmax(logits, dim=-1)
|
||||
logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
|
||||
del logits, all_logprob
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
ref_output = forward(ref_policy, query_response, processing_class.pad_token_id)
|
||||
ref_logits = ref_output.logits[:, context_length - 1 : -1]
|
||||
ref_logits /= args.temperature + 1e-7
|
||||
ref_all_logprob = F.log_softmax(ref_logits, dim=-1)
|
||||
ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1)
|
||||
del ref_output, ref_logits, ref_all_logprob
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Response Processing 1. truncate response after the first occurrence of `stop_token_id`
|
||||
postprocessed_response = response
|
||||
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
||||
postprocessed_response = truncate_response(
|
||||
args.stop_token_id, processing_class.pad_token_id, response
|
||||
)
|
||||
|
||||
# Response Processing 2. run reward model on the truncated responses
|
||||
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
||||
sequence_length = first_true_indices(postprocessed_response == processing_class.pad_token_id) - 1
|
||||
unwrapped_value_model = accelerator.unwrap_model(model).value_model
|
||||
full_value, _, _ = get_reward(
|
||||
unwrapped_value_model, query_response, processing_class.pad_token_id, context_length
|
||||
)
|
||||
value = full_value[:, context_length - 1 : -1].squeeze(-1)
|
||||
_, score, _ = get_reward(
|
||||
reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
||||
)
|
||||
|
||||
responses.append(response)
|
||||
postprocessed_responses.append(postprocessed_response)
|
||||
logprobs.append(logprob)
|
||||
ref_logprobs.append(ref_logprob)
|
||||
sequence_lengths.append(sequence_length)
|
||||
scores.append(score)
|
||||
values.append(value)
|
||||
responses = torch.cat(responses, 0)
|
||||
postprocessed_responses = torch.cat(postprocessed_responses, 0)
|
||||
logprobs = torch.cat(logprobs, 0)
|
||||
ref_logprobs = torch.cat(ref_logprobs, 0)
|
||||
sequence_lengths = torch.cat(sequence_lengths, 0)
|
||||
scores = torch.cat(scores, 0)
|
||||
values = torch.cat(values, 0)
|
||||
del (logprob, ref_logprob, full_value, value, score, unwrapped_model)
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
# Response Processing 3. Filter completion. Ensure that the sample contains stop_token_id
|
||||
# Completions not passing that filter will receive a lower score.
|
||||
contain_eos_token = torch.any(postprocessed_responses == self.processing_class.eos_token_id, dim=-1)
|
||||
if self.args.missing_eos_penalty is not None:
|
||||
scores[~contain_eos_token] -= self.args.missing_eos_penalty
|
||||
# accelerator.print(f"{scores=}, {(contain_eos_token.sum() / len(contain_eos_token))=}")
|
||||
|
||||
# be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw
|
||||
response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1)
|
||||
padding_mask = response_idxs > sequence_lengths.unsqueeze(1)
|
||||
logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB)
|
||||
ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB)
|
||||
sequence_lengths_p1 = sequence_lengths + 1
|
||||
padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1))
|
||||
values = torch.masked_fill(values, padding_mask_p1, 0)
|
||||
|
||||
# 4. compute rewards
|
||||
kl = logprobs - ref_logprobs
|
||||
non_score_reward = -args.kl_coef * kl
|
||||
rewards = non_score_reward.clone()
|
||||
actual_start = torch.arange(rewards.size(0), device=rewards.device)
|
||||
actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths)
|
||||
rewards[[actual_start, actual_end]] += scores
|
||||
|
||||
# 5. whiten rewards
|
||||
if args.whiten_rewards:
|
||||
rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False)
|
||||
rewards = torch.masked_fill(rewards, padding_mask_p1, 0)
|
||||
|
||||
# 6. compute advantages and returns
|
||||
lastgaelam = 0
|
||||
advantages_reversed = []
|
||||
gen_length = responses.shape[1]
|
||||
for t in reversed(range(gen_length)):
|
||||
nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0
|
||||
delta = rewards[:, t] + args.gamma * nextvalues - values[:, t]
|
||||
lastgaelam = delta + args.gamma * args.lam * lastgaelam
|
||||
advantages_reversed.append(lastgaelam)
|
||||
advantages = torch.stack(advantages_reversed[::-1], axis=1)
|
||||
returns = advantages + values
|
||||
advantages = masked_whiten(advantages, ~padding_mask)
|
||||
advantages = torch.masked_fill(advantages, padding_mask, 0)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# Do multiple epochs of PPO training, with a fresh random shuffle in each epoch
|
||||
for ppo_epoch_idx in range(args.num_ppo_epochs):
|
||||
b_inds = np.random.permutation(args.local_batch_size)
|
||||
minibatch_idx = 0
|
||||
for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size):
|
||||
mini_batch_end = mini_batch_start + args.local_mini_batch_size
|
||||
mini_batch_inds = b_inds[mini_batch_start:mini_batch_end]
|
||||
gradient_accumulation_idx = 0
|
||||
for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size):
|
||||
with accelerator.accumulate(model):
|
||||
micro_batch_end = micro_batch_start + args.per_device_train_batch_size
|
||||
micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end]
|
||||
mb_advantage = advantages[micro_batch_inds]
|
||||
mb_responses = responses[micro_batch_inds]
|
||||
mb_query_responses = query_responses[micro_batch_inds]
|
||||
mb_logprobs = logprobs[micro_batch_inds]
|
||||
mb_return = returns[micro_batch_inds]
|
||||
mb_values = values[micro_batch_inds]
|
||||
|
||||
output, vpred_temp = forward(model, mb_query_responses, processing_class.pad_token_id)
|
||||
logits = output.logits[:, context_length - 1 : -1]
|
||||
logits /= args.temperature + 1e-7
|
||||
new_all_logprobs = F.log_softmax(logits, dim=-1)
|
||||
new_logprobs = torch.gather(new_all_logprobs, 2, mb_responses.unsqueeze(-1)).squeeze(-1)
|
||||
new_logprobs = torch.masked_fill(
|
||||
new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB
|
||||
)
|
||||
vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1)
|
||||
vpred = torch.masked_fill(vpred, padding_mask_p1[micro_batch_inds], 0)
|
||||
vpredclipped = torch.clamp(
|
||||
vpred,
|
||||
mb_values - args.cliprange_value,
|
||||
mb_values + args.cliprange_value,
|
||||
)
|
||||
vf_losses1 = torch.square(vpred - mb_return)
|
||||
vf_losses2 = torch.square(vpredclipped - mb_return)
|
||||
vf_loss_max = torch.max(vf_losses1, vf_losses2)
|
||||
vf_loss = 0.5 * masked_mean(vf_loss_max, ~padding_mask_p1[micro_batch_inds])
|
||||
vf_clipfrac = masked_mean(
|
||||
(vf_losses2 > vf_losses1).float(), ~padding_mask_p1[micro_batch_inds]
|
||||
)
|
||||
logprobs_diff = new_logprobs - mb_logprobs
|
||||
ratio = torch.exp(logprobs_diff)
|
||||
pg_losses = -mb_advantage * ratio
|
||||
pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange)
|
||||
pg_loss_max = torch.max(pg_losses, pg_losses2)
|
||||
pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds])
|
||||
loss = pg_loss + args.vf_coef * vf_loss
|
||||
accelerator.backward(loss)
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
with torch.no_grad():
|
||||
pg_clipfrac = masked_mean(
|
||||
(pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds]
|
||||
)
|
||||
prob_dist = torch.nn.functional.softmax(logits, dim=-1)
|
||||
entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1)
|
||||
approxkl = 0.5 * (logprobs_diff**2).mean()
|
||||
approxkl_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl
|
||||
pg_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
||||
pg_clipfrac
|
||||
)
|
||||
pg_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss
|
||||
vf_loss_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss
|
||||
vf_clipfrac_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = (
|
||||
vf_clipfrac
|
||||
)
|
||||
entropy_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean()
|
||||
ratio_stats[ppo_epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean()
|
||||
gradient_accumulation_idx += 1
|
||||
minibatch_idx += 1
|
||||
# del everything and empty cache
|
||||
# fmt: off
|
||||
del (
|
||||
output, vpred_temp, logits, new_all_logprobs, new_logprobs, vpred, vpredclipped,
|
||||
vf_losses1, vf_losses2, vf_loss, vf_clipfrac, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max,
|
||||
pg_loss, loss, pg_clipfrac, prob_dist, entropy, approxkl, mb_return,
|
||||
mb_advantage, mb_values, mb_responses, mb_query_responses, mb_logprobs,
|
||||
)
|
||||
# fmt: on
|
||||
torch.cuda.empty_cache()
|
||||
with torch.no_grad():
|
||||
mean_kl = kl.sum(1).mean()
|
||||
mean_entropy = (-logprobs).sum(1).mean()
|
||||
mean_non_score_reward = non_score_reward.sum(1).mean()
|
||||
rlhf_reward = mean_non_score_reward + scores.mean()
|
||||
eps = int(self.state.episode / (time.time() - start_time))
|
||||
metrics = {}
|
||||
metrics["eps"] = eps
|
||||
metrics["objective/kl"] = self.accelerator.gather(mean_kl).mean().item()
|
||||
metrics["objective/entropy"] = self.accelerator.gather(mean_entropy).mean().item()
|
||||
metrics["objective/non_score_reward"] = self.accelerator.gather(mean_non_score_reward).mean().item()
|
||||
metrics["objective/rlhf_reward"] = self.accelerator.gather(rlhf_reward).mean().item()
|
||||
metrics["objective/scores"] = self.accelerator.gather(scores.mean()).mean().item()
|
||||
metrics["policy/approxkl_avg"] = self.accelerator.gather(approxkl_stats).mean().item()
|
||||
metrics["policy/clipfrac_avg"] = self.accelerator.gather(pg_clipfrac_stats).mean().item()
|
||||
metrics["loss/policy_avg"] = self.accelerator.gather(pg_loss_stats).mean().item()
|
||||
metrics["loss/value_avg"] = self.accelerator.gather(vf_loss_stats).mean().item()
|
||||
metrics["val/clipfrac_avg"] = self.accelerator.gather(vf_clipfrac_stats).mean().item()
|
||||
metrics["policy/entropy_avg"] = self.accelerator.gather(entropy_stats).mean().item()
|
||||
metrics["val/ratio"] = self.accelerator.gather(ratio_stats).mean().item()
|
||||
metrics["val/ratio_var"] = self.accelerator.gather(ratio_stats).var().item()
|
||||
metrics["val/num_eos_tokens"] = (responses == processing_class.eos_token_id).sum().item()
|
||||
metrics["lr"] = self.lr_scheduler.get_last_lr()[0]
|
||||
metrics["episode"] = self.state.episode
|
||||
self.state.epoch = self.state.episode / self.train_dataset_len # used by self.log
|
||||
self.state.global_step += 1
|
||||
self.log(metrics)
|
||||
|
||||
self.lr_scheduler.step()
|
||||
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
|
||||
if self.control.should_save:
|
||||
self._save_checkpoint(model, trial=None, metrics=metrics)
|
||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||
del kl, mean_kl, mean_entropy, mean_non_score_reward, scores, metrics, non_score_reward
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
if args.num_sample_generations > 0 and (update - 1) % self.sample_generations_freq == 0:
|
||||
self.generate_completions(sampling=True)
|
||||
torch.cuda.empty_cache()
|
||||
del (
|
||||
query_responses,
|
||||
responses,
|
||||
postprocessed_responses,
|
||||
logprobs,
|
||||
ref_logprobs,
|
||||
values,
|
||||
sequence_lengths,
|
||||
contain_eos_token,
|
||||
sequence_lengths_p1,
|
||||
response_idxs,
|
||||
padding_mask,
|
||||
padding_mask_p1,
|
||||
rewards,
|
||||
actual_start,
|
||||
actual_end,
|
||||
advantages,
|
||||
returns,
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# HF trainer specifics
|
||||
self.control = self.callback_handler.on_train_end(args, self.state, self.control)
|
||||
if self.control.should_save:
|
||||
self._save_checkpoint(model, trial=None, metrics=None)
|
||||
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
|
||||
|
||||
def generate_completions(self, sampling: bool = False):
|
||||
args = self.args
|
||||
processing_class = self.processing_class
|
||||
generation_config = GenerationConfig(
|
||||
max_new_tokens=self.args.response_length,
|
||||
temperature=(0.01 + 1e-7),
|
||||
top_k=0.0,
|
||||
top_p=1.0,
|
||||
do_sample=True,
|
||||
)
|
||||
|
||||
table = defaultdict(list)
|
||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||
for batch in self.eval_dataloader:
|
||||
query = batch["input_ids"]
|
||||
with torch.no_grad():
|
||||
context_length = query.shape[1]
|
||||
query_response, _ = batch_generation(
|
||||
unwrapped_model.policy,
|
||||
query,
|
||||
query.shape[0],
|
||||
processing_class.pad_token_id,
|
||||
generation_config,
|
||||
)
|
||||
response = query_response[:, context_length:]
|
||||
postprocessed_response = response
|
||||
if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0
|
||||
postprocessed_response = truncate_response(
|
||||
args.stop_token_id, processing_class.pad_token_id, response
|
||||
)
|
||||
table["query"].extend(
|
||||
gather_object(processing_class.batch_decode(query, skip_special_tokens=True))
|
||||
)
|
||||
table["model response"].extend(
|
||||
gather_object(processing_class.batch_decode(postprocessed_response))
|
||||
)
|
||||
|
||||
postprocessed_query_response = torch.cat((query, postprocessed_response), 1)
|
||||
_, score, _ = get_reward(
|
||||
self.reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
|
||||
)
|
||||
table["score"].extend(self.accelerator.gather(score).float().cpu().numpy())
|
||||
|
||||
if sampling:
|
||||
break
|
||||
df = pd.DataFrame(table)
|
||||
|
||||
if self.accelerator.is_main_process:
|
||||
print_rich_table(df.iloc[0 : 0 + 5])
|
||||
if "wandb" in args.report_to:
|
||||
import wandb
|
||||
|
||||
if wandb.run is not None:
|
||||
wandb.log({"completions": wandb.Table(dataframe=df)})
|
||||
|
||||
def create_model_card(
|
||||
self,
|
||||
model_name: Optional[str] = None,
|
||||
dataset_name: Optional[str] = None,
|
||||
tags: Union[str, List[str], None] = None,
|
||||
):
|
||||
"""
|
||||
Creates a draft of a model card using the information available to the `Trainer`.
|
||||
|
||||
Args:
|
||||
model_name (`str`, *optional*, defaults to `None`):
|
||||
The name of the model.
|
||||
dataset_name (`str`, *optional*, defaults to `None`):
|
||||
The name of the dataset used for training.
|
||||
tags (`str`, `List[str]` or `None`, *optional*, defaults to `None`):
|
||||
Tags to be associated with the model card.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
|
||||
base_model = self.model.config._name_or_path
|
||||
else:
|
||||
base_model = None
|
||||
|
||||
tags = tags or []
|
||||
if isinstance(tags, str):
|
||||
tags = [tags]
|
||||
|
||||
if hasattr(self.model.config, "unsloth_version"):
|
||||
tags.append("unsloth")
|
||||
|
||||
citation = textwrap.dedent("""\
|
||||
@article{mziegler2019fine-tuning,
|
||||
title = {{Fine-Tuning Language Models from Human Preferences}},
|
||||
author = {Daniel M. Ziegler and Nisan Stiennon and Jeffrey Wu and Tom B. Brown and Alec Radford and Dario Amodei and Paul F. Christiano and Geoffrey Irving},
|
||||
year = 2019,
|
||||
eprint = {arXiv:1909.08593}
|
||||
}""")
|
||||
|
||||
model_card = generate_model_card(
|
||||
base_model=base_model,
|
||||
model_name=model_name,
|
||||
hub_model_id=self.hub_model_id,
|
||||
dataset_name=dataset_name,
|
||||
tags=tags,
|
||||
wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
|
||||
trainer_name="PPO",
|
||||
trainer_citation=citation,
|
||||
paper_title="Fine-Tuning Language Models from Human Preferences",
|
||||
paper_id="1909.08593",
|
||||
)
|
||||
|
||||
model_card.save(os.path.join(self.args.output_dir, "README.md"))
|
||||
super().__init__(*args, **kwargs)
|
||||
|
Reference in New Issue
Block a user