mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Compare commits
84 Commits
Author | SHA1 | Date | |
---|---|---|---|
a5788ac99b | |||
3bbe7e0407 | |||
edf60e826b | |||
5d1deb1445 | |||
476c4b8dc0 | |||
e823458a6a | |||
1c0d8bca15 | |||
363369a717 | |||
aba4df02c1 | |||
98226473e4 | |||
87f4c70e60 | |||
995f1174da | |||
143e11123d | |||
346c99d222 | |||
087fe544b0 | |||
ebbd37ba99 | |||
e667550a5a | |||
57aebe9c36 | |||
85f5fd220d | |||
4dca169404 | |||
f35b68a301 | |||
5cf863576a | |||
9a28b3fd05 | |||
4f8057ad23 | |||
ab0d11d815 | |||
c674c66a45 | |||
45da5df53e | |||
04fd8d9400 | |||
bf2aed3876 | |||
0ee349dcd4 | |||
7ff6206510 | |||
e4b20ecbc4 | |||
6c2f829bb7 | |||
c4f0f41935 | |||
dc6a934269 | |||
9ce7ac6925 | |||
99553c19ae | |||
2ce8e45bb2 | |||
d1df79f83c | |||
d10f7663b0 | |||
423991c204 | |||
988d4c4e1a | |||
8534f0edf8 | |||
5095e7f948 | |||
9fcf61d706 | |||
66b043a910 | |||
f2c71771cc | |||
631c33cbb3 | |||
3f7ff60528 | |||
1705aebeba | |||
4e622a9033 | |||
eb2d5b2972 | |||
f976c6d234 | |||
abc7301bab | |||
6cfa5cfc81 | |||
a2aa0f0b09 | |||
304e208f77 | |||
4fe8b027f6 | |||
fb6ebb1e11 | |||
66078c7c01 | |||
58c0888996 | |||
486e7a4071 | |||
7630f877f9 | |||
4d862da181 | |||
22b4f548f4 | |||
4219cbfedc | |||
3bd02380c7 | |||
067db7553a | |||
93e85ed808 | |||
14e0d78807 | |||
b32656f726 | |||
9399bc113b | |||
11f122ad49 | |||
009c9a610b | |||
7712d42f8c | |||
7c2213b9e5 | |||
ddeebce176 | |||
cf68d871cf | |||
2a2676e7ec | |||
ca90cba351 | |||
4f97fb4a74 | |||
a46cd84a64 | |||
1f56bffdf8 | |||
1bfe0b8fcb |
55
.github/workflows/docker-build.yml
vendored
55
.github/workflows/docker-build.yml
vendored
@ -10,6 +10,9 @@ concurrency:
|
||||
group: docker-image-builds
|
||||
cancel-in-progress: false
|
||||
|
||||
env:
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
trl-latest:
|
||||
name: "Latest TRL GPU"
|
||||
@ -42,6 +45,31 @@ jobs:
|
||||
push: true
|
||||
tags: huggingface/trl-latest-gpu
|
||||
|
||||
- name: Post to a Slack channel
|
||||
id: slack
|
||||
#uses: slackapi/slack-github-action@v1.25.0
|
||||
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||
with:
|
||||
# Slack channel id, channel name, or user id to post message.
|
||||
# See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||
channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
||||
# For posting a rich message using Block Kit
|
||||
payload: |
|
||||
{
|
||||
"text": "trl-latest-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
||||
"blocks": [
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": "trl-latest-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
||||
trl-source:
|
||||
name: "Latest TRL + HF ecosystem from source"
|
||||
runs-on: ubuntu-latest
|
||||
@ -71,4 +99,29 @@ jobs:
|
||||
with:
|
||||
context: ./docker/trl-source-gpu
|
||||
push: true
|
||||
tags: huggingface/trl-source-gpu
|
||||
tags: huggingface/trl-source-gpu
|
||||
|
||||
- name: Post to a Slack channel
|
||||
id: slack
|
||||
#uses: slackapi/slack-github-action@v1.25.0
|
||||
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||
with:
|
||||
# Slack channel id, channel name, or user id to post message.
|
||||
# See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||
channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
||||
# For posting a rich message using Block Kit
|
||||
payload: |
|
||||
{
|
||||
"text": "trl-source-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
||||
"blocks": [
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": "trl-source-gpu Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
33
.github/workflows/tests-main.yml
vendored
33
.github/workflows/tests-main.yml
vendored
@ -4,12 +4,16 @@ on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
|
||||
env:
|
||||
CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
|
||||
|
||||
jobs:
|
||||
tests:
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ['3.9', '3.10', '3.11']
|
||||
os: ['ubuntu-latest', 'windows-latest']
|
||||
fail-fast: false
|
||||
runs-on: ${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
@ -28,7 +32,32 @@ jobs:
|
||||
pip install -U git+https://github.com/huggingface/peft.git
|
||||
pip install -U git+https://github.com/huggingface/transformers.git
|
||||
# cpu version of pytorch
|
||||
pip install -e ".[test, diffusers]"
|
||||
pip install ".[test, diffusers]"
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
make test
|
||||
- name: Post to a Slack channel
|
||||
if: always()
|
||||
id: slack
|
||||
#uses: slackapi/slack-github-action@v1.25.0
|
||||
uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001
|
||||
with:
|
||||
# Slack channel id, channel name, or user id to post message.
|
||||
# See also: https://api.slack.com/methods/chat.postMessage#channels
|
||||
channel-id: ${{ env.CI_SLACK_CHANNEL }}
|
||||
# For posting a rich message using Block Kit
|
||||
payload: |
|
||||
{
|
||||
"text": "TRL CI on transformers/PEFT main: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}",
|
||||
"blocks": [
|
||||
{
|
||||
"type": "section",
|
||||
"text": {
|
||||
"type": "mrkdwn",
|
||||
"text": "TRL CI on transformers/PEFT main: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
env:
|
||||
SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
|
||||
|
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@ -54,7 +54,7 @@ jobs:
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
# cpu version of pytorch
|
||||
pip install -e ".[test, peft, diffusers]"
|
||||
pip install ".[test, peft, diffusers]"
|
||||
- name: Test with pytest
|
||||
run: |
|
||||
make test
|
||||
|
5
.gitignore
vendored
5
.gitignore
vendored
@ -143,4 +143,7 @@ checklink/cookies.txt
|
||||
# wandb files
|
||||
nbs/wandb/
|
||||
examples/notebooks/wandb/
|
||||
wandb/
|
||||
wandb/
|
||||
|
||||
# cli scripts that are symlinked from `examples/scripts`
|
||||
trl/commands/scripts/
|
@ -5,7 +5,7 @@
|
||||
Before you start contributing make sure you installed all the dev tools:
|
||||
|
||||
```bash
|
||||
pip install -e ".[dev]"
|
||||
make dev
|
||||
```
|
||||
|
||||
## Did you find a bug?
|
||||
|
@ -2,4 +2,4 @@ include settings.ini
|
||||
include LICENSE
|
||||
include CONTRIBUTING.md
|
||||
include README.md
|
||||
recursive-exclude * __pycache__
|
||||
recursive-exclude * __pycache__
|
6
Makefile
6
Makefile
@ -5,6 +5,12 @@ check_dirs := examples tests trl
|
||||
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
|
||||
COMMAND_FILES_PATH = `pwd`/commands
|
||||
|
||||
|
||||
dev:
|
||||
[ -L "$(pwd)/trl/commands/scripts" ] && unlink "$(pwd)/trl/commands/scripts" || true
|
||||
pip install -e ".[dev]"
|
||||
ln -s `pwd`/examples/scripts/ `pwd`/trl/commands
|
||||
|
||||
test:
|
||||
python -m pytest -n auto --dist=loadfile -s -v ./tests/
|
||||
|
||||
|
130
README.md
130
README.md
@ -3,7 +3,7 @@
|
||||
</div>
|
||||
|
||||
# TRL - Transformer Reinforcement Learning
|
||||
> Full stack transformer language models with reinforcement learning.
|
||||
> Full stack library to fine-tune and align large language models.
|
||||
|
||||
<p align="center">
|
||||
<a href="https://github.com/huggingface/trl/blob/main/LICENSE">
|
||||
@ -20,61 +20,73 @@
|
||||
|
||||
## What is it?
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/TRL-readme.png">
|
||||
</div>
|
||||
The `trl` library is a full stack tool to fine-tune and align transformer language and diffusion models using methods such as Supervised Fine-tuning step (SFT), Reward Modeling (RM) and the Proximal Policy Optimization (PPO) as well as Direct Preference Optimization (DPO).
|
||||
|
||||
`trl` is a full stack library where we provide a set of tools to train transformer language models and stable diffusion models with Reinforcement Learning, from the Supervised Fine-tuning step (SFT), Reward Modeling step (RM) to the Proximal Policy Optimization (PPO) step. The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library by 🤗 Hugging Face. Therefore, pre-trained language models can be directly loaded via `transformers`. At this point, most of decoder architectures and encoder-decoder architectures are supported. Refer to the documentation or the `examples/` folder for example code snippets and how to run these tools.
|
||||
|
||||
**Highlights:**
|
||||
|
||||
- [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer): A light and friendly wrapper around `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
|
||||
- [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer): A light wrapper around `transformers` Trainer to easily fine-tune language models for human preferences (Reward Modeling).
|
||||
- [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer): A PPO trainer for language models that just needs (query, response, reward) triplets to optimise the language model.
|
||||
- [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead): A transformer model with an additional scalar output for each token which can be used as a value function in reinforcement learning.
|
||||
- [Examples](https://github.com/huggingface/trl/tree/main/examples): Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [Stack-Llama example](https://huggingface.co/blog/stackllama), etc.
|
||||
|
||||
## How PPO works
|
||||
Fine-tuning a language model via PPO consists of roughly three steps:
|
||||
|
||||
1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
|
||||
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
|
||||
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
|
||||
|
||||
This process is illustrated in the sketch below:
|
||||
The library is built on top of the [`transformers`](https://github.com/huggingface/transformers) library and thus allows to use any model architecture available there.
|
||||
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
|
||||
</div>
|
||||
## Highlights
|
||||
|
||||
- **`Efficient and scalable`**:
|
||||
- [`accelerate`](https://github.com/huggingface/accelerate) is the backbone of `trl` which allows to scale model training from a single GPU to a large scale multi-node cluster with methods such as DDP and DeepSpeed.
|
||||
- [`PEFT`](https://github.com/huggingface/peft) is fully integrated and allows to train even the largest models on modest hardware with quantisation and methods such as LoRA or QLoRA.
|
||||
- [`unsloth`](https://github.com/unslothai/unsloth) is also integrated and allows to significantly speed up training with dedicated kernels.
|
||||
- **`CLI`**: With the [CLI](https://huggingface.co/docs/trl/clis) you can fine-tune and chat with LLMs without writing any code using a single command and a flexible config system.
|
||||
- **`Trainers`**: The Trainer classes are an abstraction to apply many fine-tuning methods with ease such as the [`SFTTrainer`](https://huggingface.co/docs/trl/sft_trainer), [`DPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.DPOTrainer), [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer), [`PPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.PPOTrainer), [`CPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.CPOTrainer), and [`ORPOTrainer`](https://huggingface.co/docs/trl/trainer#trl.ORPOTrainer).
|
||||
- **`AutoModels`**: The [`AutoModelForCausalLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForCausalLMWithValueHead) & [`AutoModelForSeq2SeqLMWithValueHead`](https://huggingface.co/docs/trl/models#trl.AutoModelForSeq2SeqLMWithValueHead) classes add an additional value head to the model which allows to train them with RL algorithms such as PPO.
|
||||
- **`Examples`**: Train GPT2 to generate positive movie reviews with a BERT sentiment classifier, full RLHF using adapters only, train GPT-j to be less toxic, [StackLlama example](https://huggingface.co/blog/stackllama), etc. following the [examples](https://github.com/huggingface/trl/tree/main/examples).
|
||||
|
||||
## Installation
|
||||
|
||||
### Python package
|
||||
Install the library with pip:
|
||||
Install the library with `pip`:
|
||||
```bash
|
||||
pip install trl
|
||||
```
|
||||
|
||||
### From source
|
||||
If you want to run the examples in the repository a few additional libraries are required. Clone the repository and install it with pip:
|
||||
If you want to use the latest features before an official release you can install from source:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
pip install .
|
||||
pip install git+https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
If you wish to develop TRL, you should install in editable mode:
|
||||
### Repository
|
||||
If you want to use the examples you can clone the repository with the following command:
|
||||
```bash
|
||||
pip install -e .
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
```
|
||||
|
||||
## Command Line Interface (CLI)
|
||||
|
||||
You can use TRL Command Line Interface (CLI) to quickly get started with Supervised Fine-tuning (SFT), Direct Preference Optimization (DPO) and test your aligned model with the chat CLI:
|
||||
|
||||
**SFT:**
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
|
||||
```
|
||||
|
||||
**DPO:**
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path facebook/opt-125m --dataset_name trl-internal-testing/hh-rlhf-trl-style --output_dir opt-sft-hh-rlhf
|
||||
```
|
||||
|
||||
**Chat:**
|
||||
|
||||
```bash
|
||||
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
|
||||
Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
|
||||
|
||||
## How to use
|
||||
|
||||
For more flexibility and control over the training, you can use the dedicated trainer classes to fine-tune the model in Python.
|
||||
|
||||
### `SFTTrainer`
|
||||
|
||||
This is a basic example on how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
|
||||
This is a basic example of how to use the `SFTTrainer` from the library. The `SFTTrainer` is a light wrapper around the `transformers` Trainer to easily fine-tune language models or adapters on a custom dataset.
|
||||
|
||||
```python
|
||||
# imports
|
||||
@ -98,7 +110,7 @@ trainer.train()
|
||||
|
||||
### `RewardTrainer`
|
||||
|
||||
This is a basic example on how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
|
||||
This is a basic example of how to use the `RewardTrainer` from the library. The `RewardTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
|
||||
|
||||
```python
|
||||
# imports
|
||||
@ -124,7 +136,7 @@ trainer.train()
|
||||
|
||||
### `PPOTrainer`
|
||||
|
||||
This is a basic example on how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
|
||||
This is a basic example of how to use the `PPOTrainer` from the library. Based on a query the language model creates a response which is then evaluated. The evaluation could be a human in the loop or another model's output.
|
||||
|
||||
```python
|
||||
# imports
|
||||
@ -138,11 +150,10 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
|
||||
model_ref = create_reference_model(model)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# initialize trainer
|
||||
ppo_config = PPOConfig(
|
||||
batch_size=1,
|
||||
)
|
||||
ppo_config = PPOConfig(batch_size=1, mini_batch_size=1)
|
||||
|
||||
# encode a query
|
||||
query_txt = "This morning I went to the "
|
||||
@ -162,13 +173,50 @@ reward = [torch.tensor(1.0)]
|
||||
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)
|
||||
```
|
||||
|
||||
### `DPOTrainer`
|
||||
|
||||
`DPOTrainer` is a trainer that uses [Direct Preference Optimization algorithm](https://arxiv.org/abs/2305.18290). This is a basic example of how to use the `DPOTrainer` from the library. The `DPOTrainer` is a wrapper around the `transformers` Trainer to easily fine-tune reward models or adapters on a custom preference dataset.
|
||||
|
||||
```python
|
||||
# imports
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from trl import DPOTrainer
|
||||
|
||||
# load model and dataset - dataset needs to be in a specific format
|
||||
model = AutoModelForCausalLM.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
...
|
||||
|
||||
# load trainer
|
||||
trainer = DPOTrainer(
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
# train
|
||||
trainer.train()
|
||||
```
|
||||
|
||||
## Development
|
||||
|
||||
If you want to contribute to `trl` or customizing it to your needs make sure to read the [contribution guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) and make sure you make a dev install:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/trl.git
|
||||
cd trl/
|
||||
make dev
|
||||
```
|
||||
|
||||
## References
|
||||
|
||||
### Proximal Policy Optimisation
|
||||
The PPO implementation largely follows the structure introduced in the paper **"Fine-Tuning Language Models from Human Preferences"** by D. Ziegler et al. \[[paper](https://arxiv.org/pdf/1909.08593.pdf), [code](https://github.com/openai/lm-human-preferences)].
|
||||
|
||||
### Language models
|
||||
The language models utilize the `transformers` library by 🤗 Hugging Face. Currently, `trl` only supports `transformers` models **for text**.
|
||||
### Direct Preference Optimization
|
||||
DPO is based on the original implementation of **"Direct Preference Optimization: Your Language Model is Secretly a Reward Model"** by E. Mitchell et al. \[[paper](https://arxiv.org/pdf/2305.18290.pdf), [code](https://github.com/eric-mitchell/direct-preference-optimization)]
|
||||
|
||||
|
||||
## Citation
|
||||
|
||||
|
@ -3,6 +3,7 @@
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_dpo/"
|
||||
MODEL_NAME="HuggingFaceM4/tiny-random-LlamaForCausalLM"
|
||||
DATASET_NAME="trl-internal-testing/hh-rlhf-trl-style"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
@ -36,6 +37,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/examples/scripts/dpo.py \
|
||||
--model_name_or_path $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
|
@ -41,6 +41,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--dataset_text_field 'text' \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_seq_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
@ -56,4 +57,4 @@ echo "Starting program..."
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
||||
exit 0
|
||||
|
@ -5,6 +5,8 @@
|
||||
title: Quickstart
|
||||
- local: installation
|
||||
title: Installation
|
||||
- local: clis
|
||||
title: Get started with Command Line Interfaces (CLIs)
|
||||
- local: how_to_train
|
||||
title: PPO Training FAQ
|
||||
- local: use_model
|
||||
@ -29,8 +31,14 @@
|
||||
title: Best of N Sampling
|
||||
- local: dpo_trainer
|
||||
title: DPO Trainer
|
||||
- local: kto_trainer
|
||||
title: KTO Trainer
|
||||
- local: cpo_trainer
|
||||
title: CPO Trainer
|
||||
- local: ddpo_trainer
|
||||
title: Denoising Diffusion Policy Optimization
|
||||
- local: orpo_trainer
|
||||
title: ORPO Trainer
|
||||
- local: iterative_sft_trainer
|
||||
title: Iterative Supervised Fine-Tuning
|
||||
- local: text_environments
|
||||
|
119
docs/source/clis.mdx
Normal file
119
docs/source/clis.mdx
Normal file
@ -0,0 +1,119 @@
|
||||
# Command Line Interfaces (CLIs)
|
||||
|
||||
You can use TRL to fine-tune your Language Model with Supervised Fine-Tuning (SFT) or Direct Policy Optimization (DPO) or even chat with your model using the TRL CLIs.
|
||||
|
||||
Currently supported CLIs are:
|
||||
|
||||
- `trl sft`: fine-tune a LLM on a text/instruction dataset
|
||||
- `trl dpo`: fine-tune a LLM with DPO on a preference dataset
|
||||
- `trl chat`: quickly spin up a LLM fine-tuned for chatting
|
||||
|
||||
## Fine-tuning with the CLI
|
||||
|
||||
Before getting started, pick up a Language Model from Hugging Face Hub. Supported models can be found with the filter "text-generation" within models. Also make sure to pick up a relevant dataset for your task.
|
||||
|
||||
Before using the `sft` or `dpo` commands make sure to run:
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
and pick up the right configuration for your training setup (single / multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of `accelerate config` before running any CLI command.
|
||||
|
||||
We also recommend you passing a YAML config file to configure your training protocol. Below is a simple example of a YAML file that you can use for training your models with `trl sft` command.
|
||||
|
||||
```yaml
|
||||
model_name_or_path:
|
||||
HuggingFaceM4/tiny-random-LlamaForCausalLM
|
||||
dataset_name:
|
||||
imdb
|
||||
dataset_text_field:
|
||||
text
|
||||
report_to:
|
||||
none
|
||||
learning_rate:
|
||||
0.0001
|
||||
lr_scheduler_type:
|
||||
cosine
|
||||
```
|
||||
|
||||
Save that config in a `.yaml` and get directly started ! Note you can overwrite the arguments from the config file by explicitly passing them to the CLI, e.g.:
|
||||
|
||||
```bash
|
||||
trl sft --config example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts
|
||||
```
|
||||
|
||||
Will force-use `cosine_with_restarts` for `lr_scheduler_type`.
|
||||
|
||||
### Supported Arguments
|
||||
|
||||
We do support all arguments from `transformers.TrainingArguments`, for loading your model, we support all arguments from `~trl.ModelConfig`:
|
||||
|
||||
[[autodoc]] ModelConfig
|
||||
|
||||
You can pass any of these arguments either to the CLI or the YAML file.
|
||||
|
||||
### Supervised Fine-tuning (SFT)
|
||||
|
||||
Follow the basic instructions above and run `trl sft --output_dir <output_dir> <*args>`:
|
||||
|
||||
```bash
|
||||
trl sft --model_name_or_path facebook/opt-125m --dataset_name imdb --output_dir opt-sft-imdb
|
||||
```
|
||||
|
||||
The SFT CLI is based on the `examples/scripts/sft.py` script.
|
||||
|
||||
### Direct Policy Optimization (DPO)
|
||||
|
||||
To use the DPO CLI, you need to have a dataset in the TRL format such as
|
||||
|
||||
* TRL's Anthropic HH dataset: https://huggingface.co/datasets/trl-internal-testing/hh-rlhf-trl-style
|
||||
* TRL's OpenAI TL;DR summarization dataset: https://huggingface.co/datasets/trl-internal-testing/tldr-preference-trl-style
|
||||
|
||||
These datasets always have at least three columns `prompt, chosen, rejected`:
|
||||
|
||||
* `prompt` is a list of strings.
|
||||
* `chosen` is the chosen response in [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
* `rejected` is the rejected response [chat format](https://huggingface.co/docs/transformers/main/en/chat_templating)
|
||||
|
||||
|
||||
To do a quick start, you can run the following command:
|
||||
|
||||
```bash
|
||||
trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-trl-style
|
||||
```
|
||||
|
||||
|
||||
The DPO CLI is based on the `examples/scripts/dpo.py` script.
|
||||
|
||||
|
||||
#### Custom preference dataset
|
||||
|
||||
Format the dataset into TRL format (you can adapt the `examples/datasets/anthropic_hh.py`):
|
||||
|
||||
```bash
|
||||
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
|
||||
```
|
||||
|
||||
## Chat interface
|
||||
|
||||
The chat CLI lets you quickly load the model and talk to it. Simply run the following:
|
||||
|
||||
```bash
|
||||
trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> To use the chat CLI with the developer installation, you must run `make dev`
|
||||
>
|
||||
|
||||
Note that the chat interface relies on the tokenizer's [chat template](https://huggingface.co/docs/transformers/chat_templating) to format the inputs for the model. Make sure your tokenizer has a chat template defined.
|
||||
|
||||
Besides talking to the model there are a few commands you can use:
|
||||
|
||||
- **clear**: clears the current conversation and start a new one
|
||||
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
|
||||
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
|
||||
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
||||
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
|
||||
- **exit**: closes the interface
|
||||
|
||||
The default examples are defined in `examples/scripts/config/default_chat_config.yaml` but you can pass your own with `--config CONFIG_FILE` where you can also specify the default generation parameters.
|
102
docs/source/cpo_trainer.mdx
Normal file
102
docs/source/cpo_trainer.mdx
Normal file
@ -0,0 +1,102 @@
|
||||
# CPO Trainer
|
||||
|
||||
Contrastive Preference Optimization (CPO) as introduced in the paper [Contrastive Preference Optimization: Pushing the Boundaries of LLM Performance in Machine Translation](https://huggingface.co/papers/2401.08417) by Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, and Young Jin Kim. At a high-level, CPO trains models to
|
||||
avoid generating adequate, but not perfect translations in Machine Translation (MT) tasks. However, CPO is a general approximation to the DPO loss and can be applied to other domains like chat.
|
||||
|
||||
CPO aims to mitigate two fundamental shortcomings of SFT. First, SFT’s methodology of minimizing the discrepancy between predicted outputs and gold-standard references inherently caps model performance at the quality level of the training data. Secondly, SFT lacks a mechanism to prevent the model from rejecting mistakes in translations. The CPO objective is derived from the DPO objective.
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The CPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
|
||||
|
||||
- `prompt`
|
||||
- `chosen`
|
||||
- `rejected`
|
||||
|
||||
for example:
|
||||
|
||||
```py
|
||||
cpo_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Java",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"C++",
|
||||
],
|
||||
}
|
||||
```
|
||||
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. As can be seen a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
|
||||
|
||||
|
||||
## Expected model format
|
||||
The CPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `CPOTrainer`
|
||||
For a detailed example have a look at the `examples/scripts/cpo.py` script. At a high level we need to initialize the `CPOTrainer` with a `model` we wish to train. **Note that CPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above.
|
||||
|
||||
```py
|
||||
cpo_config = CPOConfig(
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
cpo_trainer = CPOTrainer(
|
||||
model,
|
||||
args=cpo_config,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
cpo_trainer.train()
|
||||
```
|
||||
|
||||
## Loss functions
|
||||
|
||||
Given the preference data, the `CPOTrainer` uses the sigmoid loss on the normalized likelihood via the `logsigmoid` to fit a logistic regression.
|
||||
|
||||
The [RSO](https://arxiv.org/abs/2309.06657) authors propose to use a hinge loss on the normalized likelihood from the [SLiC](https://arxiv.org/abs/2305.10425) paper. The `CPOTrainer` can be switched to this loss via the `loss_type="hinge"` argument and the `beta` in this case is the reciprocal of the margin.
|
||||
|
||||
The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical understanding of the CPO algorithms and identify an issue with overfitting and propose an alternative loss which can be used via the `loss_type="ipo"` argument to the trainer. Note that the `beta` parameter is the reciprocal of the gap between the log-likelihood ratios of the chosen vs the rejected completion pair and thus the smaller the `beta` the larger this gaps is. As per the paper the loss is averaged over log-likelihoods of the completion (unlike CPO which is summed only).
|
||||
|
||||
|
||||
## Logging
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
* `nll_loss`: the mean negative log likelihood loss of the policy model for the chosen responses
|
||||
|
||||
## CPOTrainer
|
||||
|
||||
[[autodoc]] CPOTrainer
|
||||
|
||||
|
||||
## CPOConfig
|
||||
|
||||
[[autodoc]] CPOConfig
|
@ -2,9 +2,24 @@
|
||||
|
||||
TRL supports the DPO Trainer for training language models from preference data, as described in the paper [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290) by Rafailov et al., 2023. For a full example have a look at [`examples/scripts/dpo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/dpo.py).
|
||||
|
||||
|
||||
The first step as always is to train your SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
|
||||
|
||||
## How DPO works
|
||||
|
||||
Fine-tuning a language model via DPO consists of two steps and is easier than PPO:
|
||||
|
||||
1. **Data collection**: Gather a preference dataset with positive and negative selected pairs of generation, given a prompt.
|
||||
2. **Optimization**: Maximize the log-likelihood of the DPO loss directly.
|
||||
|
||||
DPO-compatible datasets can be found with [the tag `dpo` on Hugging Face Hub](https://huggingface.co/datasets?other=dpo).
|
||||
|
||||
This process is illustrated in the sketch below (from [figure 1 of the original paper](https://arxiv.org/pdf/2305.18290.pdf)):
|
||||
|
||||
<img width="835" alt="Screenshot 2024-03-19 at 12 39 41" src="https://github.com/huggingface/trl/assets/49240599/9150fac6-3d88-4ca2-8ec6-2a6f3473216d">
|
||||
|
||||
Read more about DPO algorithm in the [original paper](https://arxiv.org/pdf/2305.18290.pdf).
|
||||
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The DPO trainer expects a very specific format for the dataset. Since the model will be trained to directly optimize the preference of which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
|
||||
@ -63,7 +78,7 @@ The DPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that
|
||||
For a detailed example have a look at the `examples/scripts/dpo.py` script. At a high level we need to initialize the `DPOTrainer` with a `model` we wish to train, a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response, the `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
```py
|
||||
dpo_trainer = DPOTrainer(
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
@ -90,7 +105,7 @@ The [IPO](https://arxiv.org/abs/2310.12036) authors provide a deeper theoretical
|
||||
|
||||
The [cDPO](https://ericmitchell.ai/cdpo.pdf) is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability that can be passed to the `DPOTrainer` via `label_smoothing` argument (between 0 and 0.5) and then a conservative DPO loss is used. Use the `loss_type="cdpo"` argument to the trainer to use it.
|
||||
|
||||
The [KTO](https://github.com/ContextualAI/HALOs/blob/main/assets/report.pdf) loss is derived to directly maximize the utility of LLM generations instead of the log-likelihood of preferences. Thus the dataset are not necessarily preferences but rather desirable vs undesirable completions. For paired preference data as required by the `DPOTrainer`, use the `loss_type="kto_pair"` argument to the trainer to utilize this loss, while for the more general case of desired and undesirable data, use the as of yet unimplemented `KTOTrainer`.
|
||||
The [KTO](https://arxiv.org/abs/2402.01306) authors directly maximize the utility of LLM generations instead of the log-likelihood of preferences. To use preference data with KTO, we recommend breaking up the n preferences into 2n examples and using [`KTOTrainer`](kto_trainer) (i.e., treating the data like an unpaired feedback dataset). Although it is possible to pass in `loss_type="kto_pair"` into DPOTrainer, this is a highly simplified version of KTO that we *do not recommend* in most cases. Please use [`KTOTrainer`](kto_trainer) when possible.
|
||||
|
||||
## Logging
|
||||
|
||||
@ -146,7 +161,7 @@ training_args = TrainingArguments(output_dir="./output")
|
||||
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref=None,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
beta=0.1,
|
||||
train_dataset=train_dataset,
|
||||
|
@ -35,6 +35,7 @@ Then, it is encouraged to launch jobs with `accelerate launch`!
|
||||
| File | Description |
|
||||
|------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|
|
||||
| [`examples/scripts/sft.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/sft.py) | This script shows how to use the `SFTTrainer` to fine tune a model or adapters into a target dataset. |
|
||||
| [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/vsft_llava.py) | This script shows how to use the `SFTTrainer` to fine tune a Vision Language Model in a chat setting, the script has been tested on a llava1.5 model so users may see unexpected behaviour in other model architectures. |
|
||||
| [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/reward_modeling.py) | This script shows how to use the `RewardTrainer` to train a reward model on your own dataset. |
|
||||
| [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset |
|
||||
| [`examples/scripts/ppo_multi_adapter.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo_multi_adapter.py) | This script shows how to use the `PPOTrainer` to train a single base model with multiple adapters. Requires you to run the example script with the reward model training beforehand. |
|
||||
|
@ -59,7 +59,7 @@ Debugging the RL pipeline can be challenging due to its complexity. Here are som
|
||||
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
|
||||
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
|
||||
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
|
||||
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a big in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
|
||||
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
|
||||
- **Inspect the reward model**: If you reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect so you might need to normalize this (e.g. reward of query+response minus reward of the query).
|
||||
|
||||
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!
|
||||
|
93
docs/source/kto_trainer.mdx
Normal file
93
docs/source/kto_trainer.mdx
Normal file
@ -0,0 +1,93 @@
|
||||
# KTO Trainer
|
||||
|
||||
TRL supports the Kahneman-Tversky Optimization (KTO) Trainer for aligning language models with binary feedback data (e.g., upvote/downvote), as described in the [paper](https://arxiv.org/abs/2402.01306) by Kawin Ethayarajh, Winnie Xu, Niklas Muennighoff, Dan Jurafsky, and Douwe Kiela.
|
||||
For a full example have a look at [`examples/scripts/kto.py`].
|
||||
|
||||
Depending on how good your base model is, you may or may not need to do SFT before KTO.
|
||||
This is different from standard RLHF and DPO, which always require SFT.
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The KTO trainer expects a very specific format for the dataset as it does not require pairwise preferences. Since the model will be trained to directly optimize examples that consist of a prompt, model completion, and a label to indicate whether the completion is "good" or "bad", we expect a dataset with the following columns:
|
||||
|
||||
- `prompt`
|
||||
- `completion`
|
||||
- `label`
|
||||
|
||||
for example:
|
||||
|
||||
```
|
||||
kto_dataset_dict = {
|
||||
"prompt": [
|
||||
"Hey, hello",
|
||||
"How are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"completion": [
|
||||
"hi nice to meet you",
|
||||
"leave me alone",
|
||||
"I don't have a name",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"C++",
|
||||
"Java",
|
||||
],
|
||||
"label": [
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
],
|
||||
}
|
||||
```
|
||||
|
||||
where the `prompt` contains the context inputs, `completion` contains the corresponding responses and `label` contains the corresponding flag that indicates if the generated completion is desired (`True`) or undesired (`False`).
|
||||
A prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
|
||||
|
||||
## Expected model format
|
||||
The KTO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `KTOTrainer`
|
||||
|
||||
For a detailed example have a look at the `examples/scripts/kto.py` script. At a high level we need to initialize the `KTOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response.
|
||||
|
||||
The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder).
|
||||
|
||||
The `desirable_weight` and `undesirable_weight` refer to the weights placed on the losses for desirable/positive and undesirable/negative examples.
|
||||
By default, they are both 1. However, if you have more of one or the other, then you should upweight the less common type such that the ratio of (`desirable_weight` * number of positives) to (`undesirable_weight` * number of negatives) is in the range 1:1 to 4:3.
|
||||
|
||||
```py
|
||||
training_args = KTOConfig(
|
||||
beta=0.1,
|
||||
desirable_weight=1.0,
|
||||
undesirable_weight=1.0,
|
||||
)
|
||||
|
||||
kto_trainer = KTOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
kto_trainer.train()
|
||||
```
|
||||
|
||||
## KTOTrainer
|
||||
|
||||
[[autodoc]] KTOTrainer
|
||||
|
||||
## KTOConfig
|
||||
|
||||
[[autodoc]] KTOConfig
|
98
docs/source/orpo_trainer.md
Normal file
98
docs/source/orpo_trainer.md
Normal file
@ -0,0 +1,98 @@
|
||||
# ORPO Trainer
|
||||
|
||||
[Odds Ratio Preference Optimization](https://arxiv.org/abs/2403.07691) (ORPO) by Jiwoo Hong, Noah Lee, and James Thorne studies the crucial role of SFT within the context of preference alignment. Using preference data the method posits that a minor penalty for the disfavored generation together with a strong adaption signal to the chosen response via a simple log odds ratio term appended to the NLL loss is sufficient for preference-aligned SFT.
|
||||
|
||||
Thus ORPO is a reference model-free preference optimization algorithm eliminating the necessity for an additional preference alignment phase thus saving compute and memory.
|
||||
|
||||
The official code can be found [xfactlab/orpo](https://github.com/xfactlab/orpo).
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The ORPO trainer expects a format identical to the DPO trainer, which should include three entries. These entries should be named as follows:
|
||||
|
||||
- `prompt`
|
||||
- `chosen`
|
||||
- `rejected`
|
||||
|
||||
for example:
|
||||
|
||||
```py
|
||||
orpo_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Java",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"C++",
|
||||
],
|
||||
}
|
||||
```
|
||||
where the `prompt` contains the context inputs, `chosen` contains the corresponding chosen responses and `rejected` contains the corresponding negative (rejected) responses. Note that a prompt can have multiple responses and this is reflected in the entries being repeated in the dictionary's value arrays.
|
||||
|
||||
## Expected model format
|
||||
The ORPO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function.
|
||||
|
||||
## Using the `ORPOTrainer`
|
||||
For a detailed example have a look at the `examples/scripts/orpo.py` script. At a high level we need to initialize the `ORPOTrainer` with a `model` we wish to train. **Note that ORPOTrainer eliminates the need to use the reference model, simplifying the optimization process.** The `beta` refers to the hyperparameter `lambda` in eq. (6) of the paper and refers to the weighting of the relative odd ratio loss in the standard cross-entropy loss used for SFT.
|
||||
|
||||
```py
|
||||
orpo_config = ORPOConfig(
|
||||
beta=0.1, # the lambda/alpha hyperparameter in the paper/code
|
||||
)
|
||||
|
||||
orpo_trainer = ORPOTrainer(
|
||||
model,
|
||||
args=orpo_config,
|
||||
train_dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
```
|
||||
After this one can then call:
|
||||
|
||||
```py
|
||||
orpo_trainer.train()
|
||||
```
|
||||
|
||||
## Logging
|
||||
|
||||
While training and evaluating we record the following reward metrics:
|
||||
|
||||
* `rewards/chosen`: the mean log probabilities of the policy model for the chosen responses scaled by beta
|
||||
* `rewards/rejected`: the mean log probabilities of the policy model for the rejected responses scaled by beta
|
||||
* `rewards/accuracies`: mean of how often the chosen rewards are > than the corresponding rejected rewards
|
||||
* `rewards/margins`: the mean difference between the chosen and corresponding rejected rewards
|
||||
|
||||
* `log_odds_chosen`: the mean log odds ratio of the chosen responses over the rejected responses
|
||||
|
||||
* `log_odds_ratio`: the mean of the `log(sigmoid(log_odds_chosen))`
|
||||
|
||||
* `nll_loss`: the mean negative log likelihood loss from the SFT part of the loss over chosen responses
|
||||
|
||||
## ORPOTrainer
|
||||
|
||||
[[autodoc]] ORPOTrainer
|
||||
|
||||
|
||||
## ORPOConfig
|
||||
|
||||
[[autodoc]] ORPOConfig
|
@ -4,6 +4,21 @@ TRL supports the [PPO](https://arxiv.org/abs/1707.06347) Trainer for training la
|
||||
|
||||
The first step is to train your SFT model (see the [SFTTrainer](sft_trainer)), to ensure the data we train on is in-distribution for the PPO algorithm. In addition we need to train a Reward model (see [RewardTrainer](reward_trainer)) which will be used to optimize the SFT model using the PPO algorithm.
|
||||
|
||||
## How PPO works
|
||||
|
||||
Fine-tuning a language model via PPO consists of roughly three steps:
|
||||
|
||||
1. **Rollout**: The language model generates a response or continuation based on query which could be the start of a sentence.
|
||||
2. **Evaluation**: The query and response are evaluated with a function, model, human feedback or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair.
|
||||
3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.
|
||||
|
||||
This process is illustrated in the sketch below:
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/trl_overview.png" width="800">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Sketch of the workflow. </p>
|
||||
</div>
|
||||
|
||||
## Expected dataset format
|
||||
|
||||
The `PPOTrainer` expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset, we then use these prompts to generate the a responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated response. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
|
||||
@ -115,7 +130,10 @@ We can then loop over all examples in the dataset and generate a response for ea
|
||||
|
||||
```py
|
||||
from tqdm import tqdm
|
||||
for epoch in tqdm(range(ppo_trainer.config.ppo_epochs), "epoch: "):
|
||||
|
||||
|
||||
epochs = 10
|
||||
for epoch in tqdm(range(epochs), "epoch: "):
|
||||
for batch in tqdm(ppo_trainer.dataloader):
|
||||
query_tensors = batch["input_ids"]
|
||||
|
||||
@ -133,7 +151,7 @@ for epoch in tqdm(range(ppo_trainer.config.ppo_epochs), "epoch: "):
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
|
||||
#### Save model
|
||||
ppo_trainer.save_model("my_ppo_model")
|
||||
ppo_trainer.save_pretrained("my_ppo_model")
|
||||
```
|
||||
|
||||
## Logging
|
||||
|
@ -3,6 +3,7 @@
|
||||
Supervised fine-tuning (or SFT for short) is a crucial step in RLHF. In TRL we provide an easy-to-use API to create your SFT models and train them with few lines of code on your dataset.
|
||||
|
||||
Check out a complete flexible example at [`examples/scripts/sft.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/sft.py).
|
||||
Experimental support for Vision Language Models is also included in the example [`examples/scripts/vsft_llava.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/vsft_llava.py).
|
||||
|
||||
## Quickstart
|
||||
|
||||
@ -271,6 +272,7 @@ trainer.train()
|
||||
```
|
||||
|
||||
Note that if you use a packed dataset and if you pass `max_steps` in the training arguments you will probably train your models for more than few epochs, depending on the way you have configured the packed dataset and the training protocol. Double check that you know and understand what you are doing.
|
||||
If you don't want to pack your `eval_dataset`, you can pass `eval_packing=False` to the `SFTTrainer` init method.
|
||||
|
||||
#### Customize your prompts using packed dataset
|
||||
|
||||
@ -604,6 +606,12 @@ You may experience some issues with GPTQ Quantization after completing training.
|
||||
|
||||
[[autodoc]] SFTTrainer
|
||||
|
||||
## ConstantLengthDataset
|
||||
## Datasets
|
||||
|
||||
In the SFTTrainer we smartly support `datasets.IterableDataset` in addition to other style datasets. This is useful if you are using large corpora that you do not want to save all to disk. The data will be tokenized and processed on the fly, even when packing is enabled.
|
||||
|
||||
Additionally, in the SFTTrainer, we support pre-tokenized datasets if they are `datasets.Dataset` or `datasets.IterableDataset`. In other words, if such a dataset has a column of `input_ids`, no further processing (tokenization or packing) will be done, and the dataset will be used as-is. This can be useful if you have pretokenized your dataset outside of this script and want to re-use it directly.
|
||||
|
||||
### ConstantLengthDataset
|
||||
|
||||
[[autodoc]] trainer.ConstantLengthDataset
|
||||
|
@ -4,6 +4,47 @@ At TRL we support PPO (Proximal Policy Optimisation) with an implementation that
|
||||
The Trainer and model classes are largely inspired from `transformers.Trainer` and `transformers.AutoModel` classes and adapted for RL.
|
||||
We also support a `RewardTrainer` that can be used to train a reward model.
|
||||
|
||||
|
||||
## CPOConfig
|
||||
|
||||
[[autodoc]] CPOConfig
|
||||
|
||||
## CPOTrainer
|
||||
|
||||
[[autodoc]] CPOTrainer
|
||||
|
||||
## DDPOConfig
|
||||
|
||||
[[autodoc]] DDPOConfig
|
||||
|
||||
## DDPOTrainer
|
||||
|
||||
[[autodoc]] DDPOTrainer
|
||||
|
||||
## DPOTrainer
|
||||
|
||||
[[autodoc]] DPOTrainer
|
||||
|
||||
## IterativeSFTTrainer
|
||||
|
||||
[[autodoc]] IterativeSFTTrainer
|
||||
|
||||
## KTOConfig
|
||||
|
||||
[[autodoc]] KTOConfig
|
||||
|
||||
## KTOTrainer
|
||||
|
||||
[[autodoc]] KTOTrainer
|
||||
|
||||
## ORPOConfig
|
||||
|
||||
[[autodoc]] ORPOConfig
|
||||
|
||||
## ORPOTrainer
|
||||
|
||||
[[autodoc]] ORPOTrainer
|
||||
|
||||
## PPOConfig
|
||||
|
||||
[[autodoc]] PPOConfig
|
||||
@ -24,22 +65,6 @@ We also support a `RewardTrainer` that can be used to train a reward model.
|
||||
|
||||
[[autodoc]] SFTTrainer
|
||||
|
||||
## DPOTrainer
|
||||
|
||||
[[autodoc]] DPOTrainer
|
||||
|
||||
## DDPOConfig
|
||||
|
||||
[[autodoc]] DDPOConfig
|
||||
|
||||
## DDPOTrainer
|
||||
|
||||
[[autodoc]] DDPOTrainer
|
||||
|
||||
## IterativeSFTTrainer
|
||||
|
||||
[[autodoc]] IterativeSFTTrainer
|
||||
|
||||
## set_seed
|
||||
|
||||
[[autodoc]] set_seed
|
||||
|
20
example_config.yaml
Normal file
20
example_config.yaml
Normal file
@ -0,0 +1,20 @@
|
||||
# This is an example configuration file of TRL CLI, you can use it for
|
||||
# SFT like that: `trl sft --config config.yaml --output_dir test-sft`
|
||||
# The YAML file supports environment variables by adding an `env` field
|
||||
# as below
|
||||
|
||||
# env:
|
||||
# CUDA_VISIBLE_DEVICES: 0
|
||||
|
||||
model_name_or_path:
|
||||
HuggingFaceM4/tiny-random-LlamaForCausalLM
|
||||
dataset_name:
|
||||
imdb
|
||||
dataset_text_field:
|
||||
text
|
||||
report_to:
|
||||
none
|
||||
learning_rate:
|
||||
0.0001
|
||||
lr_scheduler_type:
|
||||
cosine
|
122
examples/datasets/anthropic_hh.py
Normal file
122
examples/datasets/anthropic_hh.py
Normal file
@ -0,0 +1,122 @@
|
||||
import multiprocessing
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
"""
|
||||
# debug
|
||||
python -i examples/datasets/anthropic_hh.py --debug --push_to_hub
|
||||
# actual push
|
||||
python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity trl-internal-testing
|
||||
"""
|
||||
|
||||
|
||||
api = HfApi()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
|
||||
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
|
||||
hf_repo_id: Optional[str] = field(default="hh-rlhf-trl-style", metadata={"help": "The Hugging Face repository ID"})
|
||||
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
|
||||
update_main_revision: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Update the main revision of the repository"}
|
||||
)
|
||||
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
|
||||
|
||||
|
||||
# GPT-4 generated 😄 Define a function to process the input and extract the dialogue into structured format
|
||||
def extract_dialogue(input_text):
|
||||
# Split the input by lines and initialize variables
|
||||
lines = input_text.strip().split("\n\n")
|
||||
dialogue_list = []
|
||||
|
||||
# Iterate through each line and extract the dialogue
|
||||
for line in lines:
|
||||
# Check if the line starts with "Human" or "Assistant" and split accordingly
|
||||
if line.startswith("Human:"):
|
||||
role = "user"
|
||||
content = line.replace("Human: ", "").strip()
|
||||
elif line.startswith("Assistant:"):
|
||||
role = "assistant"
|
||||
content = line.replace("Assistant: ", "").strip()
|
||||
else:
|
||||
# If the line doesn't start with "Human" or "Assistant", it's part of the previous message's content
|
||||
# Append it to the last message's content
|
||||
dialogue_list[-1]["content"] += "\n\n" + line.strip()
|
||||
continue
|
||||
|
||||
# Append the extracted dialogue piece to the list
|
||||
dialogue_list.append({"role": role, "content": content})
|
||||
|
||||
return dialogue_list
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
|
||||
if args.hf_entity is None:
|
||||
args.hf_entity = api.whoami()["name"]
|
||||
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
|
||||
ds = load_dataset("Anthropic/hh-rlhf")
|
||||
if args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
|
||||
def process(row):
|
||||
row["chosen"] = extract_dialogue(row["chosen"])
|
||||
row["rejected"] = extract_dialogue(row["rejected"])
|
||||
row["prompt"] = row["chosen"][0]["content"]
|
||||
return row
|
||||
|
||||
ds = ds.map(
|
||||
process,
|
||||
num_proc=1 if args.debug else multiprocessing.cpu_count(),
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
if args.push_to_hub:
|
||||
revisions = ["main"] if args.update_main_revision else []
|
||||
revisions.append(args.revision)
|
||||
|
||||
# get the commnad used to run the script
|
||||
run_command = " ".join(["python"] + sys.argv)
|
||||
|
||||
for revision in revisions:
|
||||
ds.push_to_hub(full_repo_id, revision=revision)
|
||||
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
|
||||
|
||||
# get the name of the current file
|
||||
file_name = __file__.split("/")[-1]
|
||||
api.upload_file(
|
||||
path_or_fileobj=__file__,
|
||||
path_in_repo=file_name,
|
||||
revision=revision,
|
||||
repo_id=full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
sft_card = RepoCard.load(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
sft_card.text = f"""\
|
||||
# TRL's Anthropic HH Dataset
|
||||
|
||||
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
|
||||
|
||||
|
||||
## Reproduce this dataset
|
||||
|
||||
1. Download the `{file_name}` from the {repo_full_url}.
|
||||
2. Run `{run_command}`
|
||||
"""
|
||||
sft_card.push_to_hub(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
113
examples/datasets/tldr_preference.py
Normal file
113
examples/datasets/tldr_preference.py
Normal file
@ -0,0 +1,113 @@
|
||||
import multiprocessing
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
"""
|
||||
# debug
|
||||
python -i examples/datasets/tldr_preference.py --debug --push_to_hub
|
||||
# actual push
|
||||
python examples/datasets/tldr_preference.py --push_to_hub --hf_entity trl-internal-testing
|
||||
"""
|
||||
|
||||
|
||||
api = HfApi()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
|
||||
hf_entity: Optional[str] = field(default=None, metadata={"help": "The Hugging Face entity to use"})
|
||||
hf_repo_id: Optional[str] = field(
|
||||
default="tldr-preference-trl-style", metadata={"help": "The Hugging Face repository ID"}
|
||||
)
|
||||
revision: Optional[str] = field(default="0.1.0", metadata={"help": "The revision of the repository"})
|
||||
update_main_revision: Optional[bool] = field(
|
||||
default=True, metadata={"help": "Update the main revision of the repository"}
|
||||
)
|
||||
push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the dataset to the Hugging Face Hub"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
|
||||
if args.hf_entity is None:
|
||||
args.hf_entity = api.whoami()["name"]
|
||||
full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}"
|
||||
|
||||
ds = load_dataset("openai/summarize_from_feedback", "comparisons")
|
||||
if args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
cnndm_batches = ["batch0_cnndm", "cnndm0", "cnndm2"]
|
||||
if not args.debug:
|
||||
ds["validation_cnndm"] = ds["validation"].filter(lambda x: x["batch"] in cnndm_batches)
|
||||
ds["validation"] = ds["validation"].filter(lambda x: x["batch"] not in cnndm_batches)
|
||||
|
||||
tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:"
|
||||
cnndm_format_str = "Article:\n{article}\n\nTL;DR:"
|
||||
|
||||
def process(row):
|
||||
format_str = cnndm_format_str if row["batch"] in cnndm_batches else tldr_format_str
|
||||
row["prompt"] = format_str.format(**row["info"])
|
||||
choice = row["choice"]
|
||||
chosen = row["summaries"][choice]["text"]
|
||||
rejected = row["summaries"][1 - choice]["text"]
|
||||
row["chosen"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": chosen}]
|
||||
row["rejected"] = [{"role": "user", "content": row["prompt"]}, {"role": "assistant", "content": rejected}]
|
||||
return row
|
||||
|
||||
ds = ds.map(
|
||||
process,
|
||||
num_proc=1 if args.debug else multiprocessing.cpu_count(),
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
for key in ds: # reorder columns
|
||||
ds[key] = ds[key].select_columns(
|
||||
["prompt", "chosen", "rejected", "info", "summaries", "choice", "worker", "batch", "split", "extra"]
|
||||
)
|
||||
if args.push_to_hub:
|
||||
revisions = ["main"] if args.update_main_revision else []
|
||||
revisions.append(args.revision)
|
||||
|
||||
# get the commnad used to run the script
|
||||
run_command = " ".join(["python"] + sys.argv)
|
||||
|
||||
for revision in revisions:
|
||||
ds.push_to_hub(full_repo_id, revision=revision)
|
||||
repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}/tree/{revision}"
|
||||
|
||||
# get the name of the current file
|
||||
file_name = __file__.split("/")[-1]
|
||||
api.upload_file(
|
||||
path_or_fileobj=__file__,
|
||||
path_in_repo=file_name,
|
||||
revision=revision,
|
||||
repo_id=full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
|
||||
sft_card = RepoCard.load(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
||||
sft_card.text = f"""\
|
||||
# TRL's TL;DR Preference Dataset
|
||||
|
||||
We preprocess the dataset using our standard `prompt, chosen, rejected` format.
|
||||
|
||||
|
||||
## Reproduce this dataset
|
||||
|
||||
1. Download the `{file_name}` from the {repo_full_url}.
|
||||
2. Run `{run_command}`
|
||||
"""
|
||||
sft_card.push_to_hub(
|
||||
full_repo_id,
|
||||
repo_type="dataset",
|
||||
)
|
42
examples/datasets/tokenize_ds.py
Normal file
42
examples/datasets/tokenize_ds.py
Normal file
@ -0,0 +1,42 @@
|
||||
import multiprocessing
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, HfArgumentParser
|
||||
|
||||
|
||||
"""
|
||||
python -i examples/datasets/tokenize_ds.py --debug --model HuggingFaceH4/zephyr-7b-beta
|
||||
python -i examples/datasets/tokenize_ds.py --debug --model gpt2
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
debug: Optional[bool] = field(default=False, metadata={"help": "Enable debug mode"})
|
||||
dataset: str = field(default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The dataset to load"})
|
||||
model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
|
||||
ds = load_dataset(args.dataset)
|
||||
if args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model)
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
|
||||
|
||||
def process(row):
|
||||
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
|
||||
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
|
||||
return row
|
||||
|
||||
ds = ds.map(
|
||||
process,
|
||||
num_proc=1 if args.debug else multiprocessing.cpu_count(),
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
print(ds["train"][0]["chosen"])
|
@ -15,6 +15,7 @@ from transformers import (
|
||||
Trainer,
|
||||
TrainerCallback,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
from transformers.utils import PaddingStrategy
|
||||
|
||||
@ -89,11 +90,14 @@ class ScriptArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether to run eval after the first step"},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
set_seed(script_args.seed)
|
||||
# Load the human stack-exchange-paired dataset for tuning the reward model.
|
||||
train_dataset = load_dataset("lvwerra/stack-exchange-paired", data_dir="data/reward", split="train")
|
||||
if script_args.train_subset > 0:
|
||||
@ -129,7 +133,10 @@ training_args = TrainingArguments(
|
||||
logging_steps=10,
|
||||
optim=script_args.optim,
|
||||
lr_scheduler_type=script_args.lr_scheduler_type,
|
||||
seed=script_args.seed,
|
||||
)
|
||||
|
||||
|
||||
# Load the value-head model and tokenizer.
|
||||
tokenizer_name = script_args.tokenizer_name if script_args.tokenizer_name is not None else script_args.model_name
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_auth_token=True)
|
||||
|
@ -66,6 +66,7 @@ class ScriptArguments:
|
||||
)
|
||||
|
||||
adap_kl_ctrl: Optional[bool] = field(default=True, metadata={"help": "Use adaptive KL control, otherwise linear"})
|
||||
load_in_8bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 8bit"})
|
||||
|
||||
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
@ -180,7 +181,7 @@ lora_config = LoraConfig(
|
||||
)
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
config.model_name,
|
||||
load_in_8bit=True,
|
||||
load_in_8bit=script_args.load_in_8bit,
|
||||
device_map={"": current_device},
|
||||
peft_config=lora_config,
|
||||
)
|
||||
@ -215,11 +216,13 @@ sentiment_pipe = pipeline(
|
||||
"sentiment-analysis",
|
||||
model=reward_model_name,
|
||||
device_map={"": current_device},
|
||||
model_kwargs={"load_in_8bit": True},
|
||||
model_kwargs={"load_in_8bit": script_args.load_in_8bit},
|
||||
tokenizer=tokenizer,
|
||||
return_token_type_ids=False,
|
||||
)
|
||||
|
||||
if sentiment_pipe.model.config.pad_token_id is None:
|
||||
sentiment_pipe.model.config.pad_token_id = sentiment_pipe.model.config.eos_token_id
|
||||
# We then define the arguments to pass to the `generate` function. These arguments
|
||||
# are passed to the `generate` function of the PPOTrainer, which is a wrapper around
|
||||
# the `generate` function of the trained model.
|
||||
|
@ -4,9 +4,10 @@ from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import Dataset, load_dataset
|
||||
from peft import LoraConfig
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
|
||||
|
||||
from trl import DPOTrainer
|
||||
|
||||
@ -41,6 +42,10 @@ class ScriptArguments:
|
||||
default=True, metadata={"help": "whether to use gradient checkpointing"}
|
||||
)
|
||||
|
||||
gradient_checkpointing_use_reentrant: Optional[bool] = field(
|
||||
default=False, metadata={"help": "whether to use reentrant for gradient checkpointing"}
|
||||
)
|
||||
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
lora_dropout: Optional[float] = field(default=0.05, metadata={"help": "the lora dropout parameter"})
|
||||
lora_r: Optional[int] = field(default=8, metadata={"help": "the lora r parameter"})
|
||||
@ -54,6 +59,10 @@ class ScriptArguments:
|
||||
|
||||
output_dir: Optional[str] = field(default="./results", metadata={"help": "the output directory"})
|
||||
log_freq: Optional[int] = field(default=1, metadata={"help": "the logging frequency"})
|
||||
load_in_4bit: Optional[bool] = field(default=True, metadata={"help": "whether to load the model in 4bit"})
|
||||
model_dtype: Optional[str] = field(
|
||||
default="float16", metadata={"help": "model_dtype[float16, bfloat16, float] for loading."}
|
||||
)
|
||||
|
||||
# instrumentation
|
||||
sanity_check: Optional[bool] = field(default=False, metadata={"help": "only train on 1000 samples"})
|
||||
@ -73,6 +82,9 @@ class ScriptArguments:
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
seed: Optional[int] = field(
|
||||
default=0, metadata={"help": "Random seed that will be set at the beginning of training."}
|
||||
)
|
||||
|
||||
|
||||
def get_stack_exchange_paired(
|
||||
@ -123,12 +135,21 @@ if __name__ == "__main__":
|
||||
parser = HfArgumentParser(ScriptArguments)
|
||||
script_args = parser.parse_args_into_dataclasses()[0]
|
||||
|
||||
set_seed(script_args.seed)
|
||||
|
||||
# 1. load a pretrained model
|
||||
torch_dtype = torch.float
|
||||
if script_args.model_dtype == "float16":
|
||||
torch_dtype = torch.float16
|
||||
elif script_args.model_dtype == "bfloat16":
|
||||
torch_dtype = torch.bfloat16
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name_or_path,
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
load_in_4bit=True,
|
||||
torch_dtype=torch_dtype,
|
||||
load_in_4bit=script_args.load_in_4bit,
|
||||
device_map={"": Accelerator().local_process_index},
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
@ -138,12 +159,6 @@ if __name__ == "__main__":
|
||||
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
|
||||
]
|
||||
|
||||
model_ref = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name_or_path,
|
||||
low_cpu_mem_usage=True,
|
||||
torch_dtype=torch.float16,
|
||||
load_in_4bit=True,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
@ -181,6 +196,8 @@ if __name__ == "__main__":
|
||||
bf16=True,
|
||||
remove_unused_columns=False,
|
||||
run_name="dpo_llama2",
|
||||
gradient_checkpointing_kwargs=dict(use_reentrant=script_args.gradient_checkpointing_use_reentrant),
|
||||
seed=script_args.seed,
|
||||
)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
@ -203,7 +220,7 @@ if __name__ == "__main__":
|
||||
# 5. initialize the DPO trainer
|
||||
dpo_trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
beta=script_args.beta,
|
||||
train_dataset=train_dataset,
|
||||
|
@ -8,7 +8,14 @@ from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
from peft import AutoPeftModelForCausalLM, LoraConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser, TrainingArguments
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
HfArgumentParser,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
from trl import SFTTrainer
|
||||
from trl.import_utils import is_npu_available, is_xpu_available
|
||||
@ -27,6 +34,7 @@ class ScriptArguments:
|
||||
seq_length: Optional[int] = field(default=1024, metadata={"help": "the sequence length"})
|
||||
num_workers: Optional[int] = field(default=4, metadata={"help": "the number of workers"})
|
||||
packing: Optional[bool] = field(default=True, metadata={"help": "whether to use packing for SFTTrainer"})
|
||||
use_bnb: Optional[bool] = field(default=True, metadata={"help": "whether to use BitsAndBytes"})
|
||||
|
||||
# LoraConfig
|
||||
lora_alpha: Optional[float] = field(default=16, metadata={"help": "the lora alpha parameter"})
|
||||
@ -53,6 +61,8 @@ if training_args.group_by_length and script_args.packing:
|
||||
if training_args.gradient_checkpointing:
|
||||
raise ValueError("gradient_checkpointing not supported")
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
|
||||
def chars_token_ratio(dataset, tokenizer, nb_examples=400):
|
||||
"""
|
||||
@ -91,7 +101,7 @@ def prepare_sample_text(example):
|
||||
return text
|
||||
|
||||
|
||||
def create_datasets(tokenizer, args):
|
||||
def create_datasets(tokenizer, args, seed=None):
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
data_dir=args.subset,
|
||||
@ -104,9 +114,9 @@ def create_datasets(tokenizer, args):
|
||||
print("Loading the dataset in streaming mode")
|
||||
valid_data = dataset.take(args.size_valid_set)
|
||||
train_data = dataset.skip(args.size_valid_set)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=None)
|
||||
train_data = train_data.shuffle(buffer_size=args.shuffle_buffer, seed=seed)
|
||||
else:
|
||||
dataset = dataset.train_test_split(test_size=0.005, seed=None)
|
||||
dataset = dataset.train_test_split(test_size=0.005, seed=seed)
|
||||
train_data = dataset["train"]
|
||||
valid_data = dataset["test"]
|
||||
print(f"Size of the train set: {len(train_data)}. Size of the validation set: {len(valid_data)}")
|
||||
@ -133,11 +143,13 @@ def create_datasets(tokenizer, args):
|
||||
return train_dataset, valid_dataset
|
||||
|
||||
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
bnb_config = None
|
||||
if script_args.use_bnb:
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
base_model = AutoModelForCausalLM.from_pretrained(
|
||||
script_args.model_name,
|
||||
@ -153,7 +165,7 @@ tokenizer = AutoTokenizer.from_pretrained(script_args.model_name, trust_remote_c
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training
|
||||
|
||||
train_dataset, eval_dataset = create_datasets(tokenizer, script_args)
|
||||
train_dataset, eval_dataset = create_datasets(tokenizer, script_args, seed=training_args.seed)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=base_model,
|
||||
|
340
examples/scripts/chat.py
Normal file
340
examples/scripts/chat.py
Normal file
@ -0,0 +1,340 @@
|
||||
# flake8: noqa
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from trl.commands.cli_utils import init_zero_verbose
|
||||
|
||||
init_zero_verbose()
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import pwd
|
||||
import re
|
||||
import time
|
||||
from threading import Thread
|
||||
|
||||
import torch
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
||||
|
||||
from trl.commands.cli_utils import ChatArguments, TrlParser, init_zero_verbose
|
||||
from trl.trainer.utils import get_kbit_device_map, get_quantization_config
|
||||
|
||||
|
||||
HELP_STRING = """\
|
||||
|
||||
**TRL CHAT INTERFACE**
|
||||
|
||||
The chat interface is a simple tool to try out a chat model.
|
||||
|
||||
Besides talking to the model there are several commands:
|
||||
- **clear**: clears the current conversation and start a new one
|
||||
- **example {NAME}**: load example named `{NAME}` from the config and use it as the user input
|
||||
- **set {SETTING_NAME}={SETTING_VALUE};**: change the system prompt or generation settings (multiple settings are separated by a ';').
|
||||
- **reset**: same as clear but also resets the generation configs to defaults if they have been changed by **set**
|
||||
- **save {SAVE_NAME} (optional)**: save the current chat and settings to file by default to `./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
|
||||
- **exit**: closes the interface
|
||||
"""
|
||||
|
||||
SUPPORTED_GENERATION_KWARGS = [
|
||||
"max_new_tokens",
|
||||
"do_sample",
|
||||
"num_beams",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"repetition_penalty",
|
||||
]
|
||||
|
||||
SETTING_RE = r"^set\s+[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+(?:;\s*[A-Za-z\s_]+=[A-Za-z\d\s.!\"#$%&'()*+,-/:<=>?@\[\]^_`{|}~]+)*$"
|
||||
|
||||
|
||||
class RichInterface:
|
||||
def __init__(self, model_name=None, user_name=None):
|
||||
self._console = Console()
|
||||
if model_name is None:
|
||||
self.model_name = "assistant"
|
||||
else:
|
||||
self.model_name = model_name
|
||||
if user_name is None:
|
||||
self.user_name = "user"
|
||||
else:
|
||||
self.user_name = user_name
|
||||
|
||||
def stream_output(self, output_stream):
|
||||
"""Stream output from a role."""
|
||||
# This method is originally from the FastChat CLI: https://github.com/lm-sys/FastChat/blob/main/fastchat/serve/cli.py
|
||||
# Create a Live context for updating the console output
|
||||
text = ""
|
||||
self._console.print(f"[bold blue]<{self.model_name}>:")
|
||||
with Live(console=self._console, refresh_per_second=4) as live:
|
||||
# Read lines from the stream
|
||||
for i, outputs in enumerate(output_stream):
|
||||
if not outputs or i == 0:
|
||||
continue
|
||||
text += outputs
|
||||
# Render the accumulated text as Markdown
|
||||
# NOTE: this is a workaround for the rendering "unstandard markdown"
|
||||
# in rich. The chatbots output treat "\n" as a new line for
|
||||
# better compatibility with real-world text. However, rendering
|
||||
# in markdown would break the format. It is because standard markdown
|
||||
# treat a single "\n" in normal text as a space.
|
||||
# Our workaround is adding two spaces at the end of each line.
|
||||
# This is not a perfect solution, as it would
|
||||
# introduce trailing spaces (only) in code block, but it works well
|
||||
# especially for console output, because in general the console does not
|
||||
# care about trailing spaces.
|
||||
lines = []
|
||||
for line in text.splitlines():
|
||||
lines.append(line)
|
||||
if line.startswith("```"):
|
||||
# Code block marker - do not add trailing spaces, as it would
|
||||
# break the syntax highlighting
|
||||
lines.append("\n")
|
||||
else:
|
||||
lines.append(" \n")
|
||||
markdown = Markdown("".join(lines).strip(), code_theme="github-dark")
|
||||
# Update the Live console output
|
||||
live.update(markdown)
|
||||
self._console.print()
|
||||
return text
|
||||
|
||||
def input(self):
|
||||
input = self._console.input(f"[bold red]<{self.user_name}>:\n")
|
||||
self._console.print()
|
||||
return input
|
||||
|
||||
def clear(self):
|
||||
self._console.clear()
|
||||
|
||||
def print_user_message(self, text):
|
||||
self._console.print(f"[bold red]<{self.user_name}>:[/ bold red]\n{text}")
|
||||
self._console.print()
|
||||
|
||||
def print_green(self, text):
|
||||
self._console.print(f"[bold green]{text}")
|
||||
self._console.print()
|
||||
|
||||
def print_red(self, text):
|
||||
self._console.print(f"[bold red]{text}")
|
||||
self._console.print()
|
||||
|
||||
def print_help(self):
|
||||
self._console.print(Markdown(HELP_STRING))
|
||||
self._console.print()
|
||||
|
||||
|
||||
def get_username():
|
||||
return pwd.getpwuid(os.getuid())[0]
|
||||
|
||||
|
||||
def create_default_filename(model_name):
|
||||
time_str = time.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
return f"{model_name}/chat_{time_str}.json"
|
||||
|
||||
|
||||
def save_chat(chat, args, filename):
|
||||
output_dict = {}
|
||||
output_dict["settings"] = vars(args)
|
||||
output_dict["chat_history"] = chat
|
||||
|
||||
folder = args.save_folder
|
||||
|
||||
if filename is None:
|
||||
filename = create_default_filename(args.model_name_or_path)
|
||||
filename = os.path.join(folder, filename)
|
||||
os.makedirs(os.path.dirname(filename), exist_ok=True)
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json.dump(output_dict, f, indent=4)
|
||||
return os.path.abspath(filename)
|
||||
|
||||
|
||||
def clear_chat_history(system_prompt):
|
||||
if system_prompt is None:
|
||||
chat = []
|
||||
else:
|
||||
chat = [{"role": "system", "content": system_prompt}]
|
||||
return chat
|
||||
|
||||
|
||||
def parse_settings(user_input, current_args, interface):
|
||||
settings = user_input[4:].strip().split(";")
|
||||
settings = [(setting.split("=")[0], setting[len(setting.split("=")[0]) + 1 :]) for setting in settings]
|
||||
settings = dict(settings)
|
||||
error = False
|
||||
|
||||
for name in settings:
|
||||
if hasattr(current_args, name):
|
||||
try:
|
||||
if isinstance(getattr(current_args, name), bool):
|
||||
if settings[name] == "True":
|
||||
settings[name] = True
|
||||
elif settings[name] == "False":
|
||||
settings[name] = False
|
||||
else:
|
||||
raise ValueError
|
||||
else:
|
||||
settings[name] = type(getattr(current_args, name))(settings[name])
|
||||
except ValueError:
|
||||
interface.print_red(
|
||||
f"Cannot cast setting {name} (={settings[name]}) to {type(getattr(current_args, name))}."
|
||||
)
|
||||
else:
|
||||
interface.print_red(f"There is no '{name}' setting.")
|
||||
|
||||
if error:
|
||||
interface.print_red("There was an issue parsing the settings. No settings have been changed.")
|
||||
return current_args, False
|
||||
else:
|
||||
for name in settings:
|
||||
setattr(current_args, name, settings[name])
|
||||
interface.print_green(f"Set {name} to {settings[name]}.")
|
||||
|
||||
time.sleep(1.5) # so the user has time to read the changes
|
||||
return current_args, True
|
||||
|
||||
|
||||
def load_model_and_tokenizer(args):
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, revision=args.model_revision)
|
||||
|
||||
torch_dtype = args.torch_dtype if args.torch_dtype in ["auto", None] else getattr(torch, args.torch_dtype)
|
||||
quantization_config = get_quantization_config(args)
|
||||
model_kwargs = dict(
|
||||
revision=args.model_revision,
|
||||
trust_remote_code=args.trust_remote_code,
|
||||
attn_implementation=args.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
|
||||
|
||||
if getattr(model, "hf_device_map", None) is None:
|
||||
model = model.to(args.device)
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def chat_cli():
|
||||
parser = TrlParser(ChatArguments)
|
||||
args = parser.parse_args_into_dataclasses()[0]
|
||||
if args.config == "default":
|
||||
args.config = os.path.join(os.path.dirname(__file__), "config/default_chat_config.yaml")
|
||||
if args.config.lower() == "none":
|
||||
args.config = None
|
||||
args = parser.update_dataclasses_with_config([args])[0]
|
||||
if args.examples is None:
|
||||
args.examples = {}
|
||||
|
||||
current_args = copy.deepcopy(args)
|
||||
|
||||
if args.user is None:
|
||||
user = get_username()
|
||||
else:
|
||||
user = args.user
|
||||
|
||||
model, tokenizer = load_model_and_tokenizer(args)
|
||||
generation_streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
|
||||
|
||||
interface = RichInterface(model_name=args.model_name_or_path, user_name=user)
|
||||
interface.clear()
|
||||
chat = clear_chat_history(current_args.system_prompt)
|
||||
while True:
|
||||
try:
|
||||
user_input = interface.input()
|
||||
|
||||
if user_input == "clear":
|
||||
chat = clear_chat_history(current_args.system_prompt)
|
||||
interface.clear()
|
||||
continue
|
||||
|
||||
if user_input == "help":
|
||||
interface.print_help()
|
||||
continue
|
||||
|
||||
if user_input == "exit":
|
||||
break
|
||||
|
||||
if user_input == "reset":
|
||||
interface.clear()
|
||||
current_args = copy.deepcopy(args)
|
||||
chat = clear_chat_history(current_args.system_prompt)
|
||||
continue
|
||||
|
||||
if user_input.startswith("save") and len(user_input.split()) < 2:
|
||||
split_input = user_input.split()
|
||||
|
||||
if len(split_input) == 2:
|
||||
filename = split_input[1]
|
||||
else:
|
||||
filename = None
|
||||
filename = save_chat(chat, current_args, filename)
|
||||
interface.print_green(f"Chat saved in {filename}!")
|
||||
continue
|
||||
|
||||
if re.match(SETTING_RE, user_input):
|
||||
current_args, success = parse_settings(user_input, current_args, interface)
|
||||
if success:
|
||||
chat = []
|
||||
interface.clear()
|
||||
continue
|
||||
|
||||
if user_input.startswith("example") and len(user_input.split()) == 2:
|
||||
example_name = user_input.split()[1]
|
||||
if example_name in current_args.examples:
|
||||
interface.clear()
|
||||
chat = []
|
||||
interface.print_user_message(current_args.examples[example_name]["text"])
|
||||
user_input = current_args.examples[example_name]["text"]
|
||||
else:
|
||||
interface.print_red(
|
||||
f"Example {example_name} not found in list of available examples: {list(current_args.examples.keys())}."
|
||||
)
|
||||
continue
|
||||
|
||||
chat.append({"role": "user", "content": user_input})
|
||||
|
||||
generation_kwargs = dict(
|
||||
inputs=tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
||||
model.device
|
||||
),
|
||||
streamer=generation_streamer,
|
||||
max_new_tokens=current_args.max_new_tokens,
|
||||
do_sample=current_args.do_sample,
|
||||
num_beams=current_args.num_beams,
|
||||
temperature=current_args.temperature,
|
||||
top_k=current_args.top_k,
|
||||
top_p=current_args.top_p,
|
||||
repetition_penalty=current_args.repetition_penalty,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
model_output = interface.stream_output(generation_streamer)
|
||||
thread.join()
|
||||
chat.append({"role": "assistant", "content": model_output})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
chat_cli()
|
13
examples/scripts/config/default_chat_config.yaml
Normal file
13
examples/scripts/config/default_chat_config.yaml
Normal file
@ -0,0 +1,13 @@
|
||||
examples:
|
||||
llama:
|
||||
text: There is a Llama in my lawn, how can I get rid of it?
|
||||
code:
|
||||
text: Write a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end].
|
||||
helicopter:
|
||||
text: How many helicopters can a human eat in one sitting?
|
||||
numbers:
|
||||
text: Count to 10 but skip every number ending with an 'e'
|
||||
birds:
|
||||
text: Why aren't birds real?
|
||||
socks:
|
||||
text: Why is it important to eat socks after meditating?
|
121
examples/scripts/cpo.py
Normal file
121
examples/scripts/cpo.py
Normal file
@ -0,0 +1,121 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Run the CPO training script with the following command with some example arguments.
|
||||
In general, the optimal configuration for CPO will be similar to that of DPO:
|
||||
|
||||
# regular:
|
||||
python examples/scripts/cpo.py \
|
||||
--model_name_or_path=gpt2 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--max_steps 1000 \
|
||||
--learning_rate 8e-6 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
--eval_steps 500 \
|
||||
--output_dir="gpt2-aligned-cpo" \
|
||||
--warmup_steps 150 \
|
||||
--report_to wandb \
|
||||
--bf16 \
|
||||
--logging_first_step \
|
||||
--no_remove_unused_columns
|
||||
|
||||
# peft:
|
||||
python examples/scripts/cpo.py \
|
||||
--model_name_or_path=gpt2 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--max_steps 1000 \
|
||||
--learning_rate 8e-5 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
--eval_steps 500 \
|
||||
--output_dir="gpt2-lora-aligned-cpo" \
|
||||
--optim rmsprop \
|
||||
--warmup_steps 150 \
|
||||
--report_to wandb \
|
||||
--bf16 \
|
||||
--logging_first_step \
|
||||
--no_remove_unused_columns \
|
||||
--use_peft \
|
||||
--lora_r=16 \
|
||||
--lora_alpha=16
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
|
||||
|
||||
from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
dataset: str = field(
|
||||
default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig))
|
||||
args, cpo_args, model_config = parser.parse_args_into_dataclasses()
|
||||
|
||||
################
|
||||
# Model & Tokenizer
|
||||
################
|
||||
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
|
||||
peft_config = get_peft_config(model_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
ds = load_dataset(args.dataset)
|
||||
if cpo_args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||
|
||||
def process(row):
|
||||
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
|
||||
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
|
||||
return row
|
||||
|
||||
ds = ds.map(
|
||||
process,
|
||||
num_proc=1 if cpo_args.debug else multiprocessing.cpu_count(),
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
train_dataset = ds["train"]
|
||||
eval_dataset = ds["test"]
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = CPOTrainer(
|
||||
model,
|
||||
args=cpo_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
peft_config=get_peft_config(model_config),
|
||||
)
|
||||
|
||||
# train and save the model
|
||||
trainer.train()
|
||||
trainer.save_model(cpo_args.output_dir)
|
@ -175,7 +175,7 @@ def image_outputs_logger(image_data, global_step, accelerate_logger):
|
||||
for i, image in enumerate(images):
|
||||
prompt = prompts[i]
|
||||
reward = rewards[i].item()
|
||||
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0)
|
||||
result[f"{prompt:.25} | {reward:.2f}"] = image.unsqueeze(0).float()
|
||||
|
||||
accelerate_logger.log_images(
|
||||
result,
|
||||
|
@ -1,3 +1,4 @@
|
||||
# flake8: noqa
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -14,9 +15,9 @@
|
||||
"""
|
||||
# regular:
|
||||
python examples/scripts/dpo.py \
|
||||
--dataset_name=trl-internal-testing/hh-rlhf-trl-style \
|
||||
--model_name_or_path=gpt2 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--max_steps 1000 \
|
||||
--learning_rate 1e-3 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
@ -30,9 +31,9 @@ python examples/scripts/dpo.py \
|
||||
|
||||
# peft:
|
||||
python examples/scripts/dpo.py \
|
||||
--dataset_name=trl-internal-testing/hh-rlhf-trl-style \
|
||||
--model_name_or_path=gpt2 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--max_steps 1000 \
|
||||
--learning_rate 1e-3 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
@ -48,76 +49,48 @@ python examples/scripts/dpo.py \
|
||||
--lora_r=16 \
|
||||
--lora_alpha=16
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Optional
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
|
||||
|
||||
from trl.commands.cli_utils import DpoScriptArguments, init_zero_verbose, TrlParser
|
||||
|
||||
if TRL_USE_RICH:
|
||||
init_zero_verbose()
|
||||
FORMAT = "%(message)s"
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
||||
|
||||
from trl import DPOTrainer, ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config
|
||||
from trl import (
|
||||
DPOTrainer,
|
||||
ModelConfig,
|
||||
RichProgressCallback,
|
||||
get_kbit_device_map,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
|
||||
max_length: int = field(default=512, metadata={"help": "max length of each sample"})
|
||||
max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
|
||||
max_target_length: int = field(
|
||||
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
|
||||
)
|
||||
sanity_check: bool = field(default=True, metadata={"help": "only train on 1000 samples"})
|
||||
ignore_bias_buffers: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "debug argument for distributed training;"
|
||||
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
|
||||
|
||||
|
||||
def extract_anthropic_prompt(prompt_and_response):
|
||||
"""Extract the anthropic prompt from a prompt and response pair."""
|
||||
search_term = "\n\nAssistant:"
|
||||
search_term_idx = prompt_and_response.rfind(search_term)
|
||||
assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
|
||||
return prompt_and_response[: search_term_idx + len(search_term)]
|
||||
|
||||
|
||||
def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: Optional[str] = None) -> Dataset:
|
||||
"""Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.
|
||||
|
||||
The dataset is converted to a dictionary with the following structure:
|
||||
{
|
||||
'prompt': List[str],
|
||||
'chosen': List[str],
|
||||
'rejected': List[str],
|
||||
}
|
||||
|
||||
Prompts should be structured as follows:
|
||||
\n\nHuman: <prompt>\n\nAssistant:
|
||||
Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
|
||||
"""
|
||||
dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
|
||||
if sanity_check:
|
||||
dataset = dataset.select(range(min(len(dataset), 1000)))
|
||||
|
||||
def split_prompt_and_responses(sample) -> Dict[str, str]:
|
||||
prompt = extract_anthropic_prompt(sample["chosen"])
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"chosen": sample["chosen"][len(prompt) :],
|
||||
"rejected": sample["rejected"][len(prompt) :],
|
||||
}
|
||||
|
||||
return dataset.map(split_prompt_and_responses)
|
||||
if TRL_USE_RICH:
|
||||
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
|
||||
args, training_args, model_config = parser.parse_args_into_dataclasses()
|
||||
parser = TrlParser((DpoScriptArguments, TrainingArguments, ModelConfig))
|
||||
args, training_args, model_config = parser.parse_args_and_config()
|
||||
|
||||
# Force use our print callback
|
||||
if TRL_USE_RICH:
|
||||
training_args.disable_tqdm = True
|
||||
console = Console()
|
||||
|
||||
################
|
||||
# Model & Tokenizer
|
||||
@ -146,34 +119,66 @@ if __name__ == "__main__":
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
|
||||
if args.ignore_bias_buffers:
|
||||
# torch distributed hack
|
||||
model._ddp_params_and_buffers_to_ignore = [
|
||||
name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
|
||||
]
|
||||
|
||||
################
|
||||
# Optional rich context managers
|
||||
###############
|
||||
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the DPOTrainer...")
|
||||
save_context = (
|
||||
nullcontext()
|
||||
if not TRL_USE_RICH
|
||||
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
|
||||
)
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
train_dataset = get_hh("train", sanity_check=args.sanity_check)
|
||||
eval_dataset = get_hh("test", sanity_check=args.sanity_check)
|
||||
ds = load_dataset(args.dataset_name)
|
||||
if args.sanity_check:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
|
||||
def process(row):
|
||||
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
|
||||
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
|
||||
return row
|
||||
|
||||
ds = ds.map(
|
||||
process,
|
||||
num_proc=multiprocessing.cpu_count(),
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
train_dataset = ds["train"]
|
||||
eval_dataset = ds["test"]
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
beta=args.beta,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
max_length=args.max_length,
|
||||
max_target_length=args.max_target_length,
|
||||
max_prompt_length=args.max_prompt_length,
|
||||
generate_during_eval=args.generate_during_eval,
|
||||
peft_config=get_peft_config(model_config),
|
||||
)
|
||||
with init_context:
|
||||
trainer = DPOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=training_args,
|
||||
beta=args.beta,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
max_length=args.max_length,
|
||||
max_target_length=args.max_target_length,
|
||||
max_prompt_length=args.max_prompt_length,
|
||||
generate_during_eval=args.generate_during_eval,
|
||||
peft_config=get_peft_config(model_config),
|
||||
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model(training_args.output_dir)
|
||||
|
||||
with save_context:
|
||||
trainer.save_model(training_args.output_dir)
|
||||
|
115
examples/scripts/kto.py
Normal file
115
examples/scripts/kto.py
Normal file
@ -0,0 +1,115 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.
|
||||
|
||||
# Full training:
|
||||
python examples/scripts/kto.py \
|
||||
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
|
||||
--per_device_train_batch_size 16 \
|
||||
--num_train_epochs 1 \
|
||||
--learning_rate 1e-5 \
|
||||
--lr_scheduler_type=cosine \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
--eval_steps 500 \
|
||||
--output_dir=kto-aligned-model \
|
||||
--warmup_ratio 0.1 \
|
||||
--report_to wandb \
|
||||
--bf16 \
|
||||
--logging_first_step
|
||||
|
||||
# QLoRA:
|
||||
python examples/scripts/kto.py \
|
||||
--model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
|
||||
--per_device_train_batch_size 8 \
|
||||
--num_train_epochs 1 \
|
||||
--learning_rate 1e-4 \
|
||||
--lr_scheduler_type=cosine \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
--eval_steps 500 \
|
||||
--output_dir=kto-aligned-model-lora \
|
||||
--warmup_ratio 0.1 \
|
||||
--report_to wandb \
|
||||
--bf16 \
|
||||
--logging_first_step \
|
||||
--use_peft \
|
||||
--load_in_4bit \
|
||||
--lora_target_modules=all-linear \
|
||||
--lora_r=16 \
|
||||
--lora_alpha=16
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
|
||||
|
||||
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, setup_chat_format
|
||||
|
||||
|
||||
# Define and parse arguments.
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
"""
|
||||
The arguments for the KTO training script.
|
||||
"""
|
||||
|
||||
dataset_name: str = "trl-lib/kto-mix-14k"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig))
|
||||
script_args, kto_args, model_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
# Load a pretrained model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
|
||||
model_ref = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# If we are aligning a base model, we use ChatML as the default template
|
||||
if tokenizer.chat_template is None:
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
|
||||
# Load the dataset
|
||||
dataset = load_dataset(script_args.dataset_name)
|
||||
|
||||
# Apply chat template
|
||||
def format_dataset(example):
|
||||
example["prompt"] = tokenizer.apply_chat_template(example["prompt"], tokenize=False)
|
||||
example["completion"] = tokenizer.apply_chat_template(example["completion"], tokenize=False)
|
||||
return example
|
||||
|
||||
formatted_dataset = dataset.map(format_dataset)
|
||||
|
||||
# Initialize the KTO trainer
|
||||
kto_trainer = KTOTrainer(
|
||||
model,
|
||||
model_ref,
|
||||
args=kto_args,
|
||||
train_dataset=formatted_dataset["train"],
|
||||
eval_dataset=formatted_dataset["test"],
|
||||
tokenizer=tokenizer,
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
|
||||
# Train and push the model to the Hub
|
||||
kto_trainer.train()
|
||||
kto_trainer.save_model(kto_args.output_dir)
|
||||
kto_trainer.push_to_hub()
|
121
examples/scripts/orpo.py
Normal file
121
examples/scripts/orpo.py
Normal file
@ -0,0 +1,121 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Run the ORPO training script with the following command with some example arguments.
|
||||
In general, the optimal configuration for ORPO will be similar to that of DPO without the need for a reference model:
|
||||
|
||||
# regular:
|
||||
python examples/scripts/orpo.py \
|
||||
--model_name_or_path=gpt2 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--max_steps 1000 \
|
||||
--learning_rate 8e-6 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
--eval_steps 500 \
|
||||
--output_dir="gpt2-aligned-orpo" \
|
||||
--warmup_steps 150 \
|
||||
--report_to wandb \
|
||||
--bf16 \
|
||||
--logging_first_step \
|
||||
--no_remove_unused_columns
|
||||
|
||||
# peft:
|
||||
python examples/scripts/orpo.py \
|
||||
--model_name_or_path=gpt2 \
|
||||
--per_device_train_batch_size 4 \
|
||||
--max_steps 1000 \
|
||||
--learning_rate 8e-5 \
|
||||
--gradient_accumulation_steps 1 \
|
||||
--logging_steps 10 \
|
||||
--eval_steps 500 \
|
||||
--output_dir="gpt2-lora-aligned-orpo" \
|
||||
--optim rmsprop \
|
||||
--warmup_steps 150 \
|
||||
--report_to wandb \
|
||||
--bf16 \
|
||||
--logging_first_step \
|
||||
--no_remove_unused_columns \
|
||||
--use_peft \
|
||||
--lora_r=16 \
|
||||
--lora_alpha=16
|
||||
"""
|
||||
|
||||
import multiprocessing
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
|
||||
|
||||
from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
dataset: str = field(
|
||||
default="trl-internal-testing/hh-rlhf-trl-style", metadata={"help": "The name of the dataset to use."}
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig))
|
||||
args, orpo_args, model_config = parser.parse_args_into_dataclasses()
|
||||
|
||||
################
|
||||
# Model & Tokenizer
|
||||
################
|
||||
model = AutoModelForCausalLM.from_pretrained(model_config.model_name_or_path)
|
||||
peft_config = get_peft_config(model_config)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
ds = load_dataset(args.dataset)
|
||||
if orpo_args.debug:
|
||||
for key in ds:
|
||||
ds[key] = ds[key].select(range(50))
|
||||
if tokenizer.chat_template is None:
|
||||
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
|
||||
|
||||
def process(row):
|
||||
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
|
||||
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
|
||||
return row
|
||||
|
||||
ds = ds.map(
|
||||
process,
|
||||
num_proc=1 if orpo_args.debug else multiprocessing.cpu_count(),
|
||||
load_from_cache_file=False,
|
||||
)
|
||||
train_dataset = ds["train"]
|
||||
eval_dataset = ds["test"]
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = ORPOTrainer(
|
||||
model,
|
||||
args=orpo_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
peft_config=get_peft_config(model_config),
|
||||
)
|
||||
|
||||
# train and save the model
|
||||
trainer.train()
|
||||
trainer.save_model(orpo_args.output_dir)
|
@ -27,6 +27,7 @@ python examples/scripts/reward_modeling.py \
|
||||
--evaluation_strategy="steps" \
|
||||
--max_length=512 \
|
||||
"""
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
@ -64,6 +65,12 @@ if __name__ == "__main__":
|
||||
model_config.model_name_or_path, num_labels=1, **model_kwargs
|
||||
)
|
||||
|
||||
if model_config.lora_task_type != "SEQ_CLS":
|
||||
warnings.warn(
|
||||
"You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
|
||||
" Make sure to pass --lora_task_type SEQ_CLS when using this script."
|
||||
)
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
|
@ -1,3 +1,4 @@
|
||||
# flake8: noqa
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -43,30 +44,50 @@ python examples/scripts/sft.py \
|
||||
--lora_r=64 \
|
||||
--lora_alpha=16
|
||||
"""
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
|
||||
|
||||
from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser
|
||||
|
||||
if TRL_USE_RICH:
|
||||
init_zero_verbose()
|
||||
FORMAT = "%(message)s"
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments
|
||||
|
||||
from trl import ModelConfig, SFTTrainer, get_kbit_device_map, get_peft_config, get_quantization_config
|
||||
from tqdm.rich import tqdm
|
||||
from transformers import AutoTokenizer, TrainingArguments
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
RichProgressCallback,
|
||||
SFTTrainer,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
get_kbit_device_map,
|
||||
)
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScriptArguments:
|
||||
dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
|
||||
dataset_text_field: str = field(default="text", metadata={"help": "the text field of the dataset"})
|
||||
max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"})
|
||||
if TRL_USE_RICH:
|
||||
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((ScriptArguments, TrainingArguments, ModelConfig))
|
||||
args, training_args, model_config = parser.parse_args_into_dataclasses()
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
|
||||
args, training_args, model_config = parser.parse_args_and_config()
|
||||
|
||||
# Force use our print callback
|
||||
if TRL_USE_RICH:
|
||||
training_args.disable_tqdm = True
|
||||
console = Console()
|
||||
|
||||
################
|
||||
# Model & Tokenizer
|
||||
@ -96,20 +117,35 @@ if __name__ == "__main__":
|
||||
train_dataset = raw_datasets["train"]
|
||||
eval_dataset = raw_datasets["test"]
|
||||
|
||||
################
|
||||
# Optional rich context managers
|
||||
###############
|
||||
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
|
||||
save_context = (
|
||||
nullcontext()
|
||||
if not TRL_USE_RICH
|
||||
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
|
||||
)
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
trainer = SFTTrainer(
|
||||
model=model_config.model_name_or_path,
|
||||
model_init_kwargs=model_kwargs,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
dataset_text_field="text",
|
||||
max_seq_length=args.max_seq_length,
|
||||
tokenizer=tokenizer,
|
||||
packing=True,
|
||||
peft_config=get_peft_config(model_config),
|
||||
)
|
||||
with init_context:
|
||||
trainer = SFTTrainer(
|
||||
model=model_config.model_name_or_path,
|
||||
model_init_kwargs=model_kwargs,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
dataset_text_field=args.dataset_text_field,
|
||||
max_seq_length=args.max_seq_length,
|
||||
tokenizer=tokenizer,
|
||||
packing=args.packing,
|
||||
peft_config=get_peft_config(model_config),
|
||||
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.save_model(training_args.output_dir)
|
||||
|
||||
with save_context:
|
||||
trainer.save_model(training_args.output_dir)
|
||||
|
210
examples/scripts/vsft_llava.py
Normal file
210
examples/scripts/vsft_llava.py
Normal file
@ -0,0 +1,210 @@
|
||||
# flake8: noqa
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
# regular:
|
||||
python examples/scripts/vsft.py \
|
||||
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \
|
||||
--report_to="wandb" \
|
||||
--learning_rate=1.4e-5 \
|
||||
--per_device_train_batch_size=8 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--output_dir="data/vsft-llava-1.5-7b-hf" \
|
||||
--logging_steps=5 \
|
||||
--num_train_epochs=1 \
|
||||
--push_to_hub \
|
||||
--gradient_checkpointing \
|
||||
--remove_unused_columns=False \
|
||||
--torch_dtype=float16 \
|
||||
--fp16=True \
|
||||
--dataset_name=HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
|
||||
# peft:
|
||||
python examples/scripts/vsft.py \
|
||||
--model_name_or_path="llava-hf/llava-1.5-7b-hf" \
|
||||
--report_to="wandb" \
|
||||
--learning_rate=1.4e-5 \
|
||||
--per_device_train_batch_size=8 \
|
||||
--gradient_accumulation_steps=1 \
|
||||
--output_dir="data/vsft-llava-1.5-7b-hf" \
|
||||
--logging_steps=5 \
|
||||
--num_train_epochs=1 \
|
||||
--push_to_hub \
|
||||
--gradient_checkpointing \
|
||||
--remove_unused_columns=False \
|
||||
--torch_dtype=float16 \
|
||||
--fp16=True \
|
||||
--dataset_name=HuggingFaceH4/llava-instruct-mix-vsft \
|
||||
--use_peft=True \
|
||||
--lora_r=64 \
|
||||
--lora_alpha=16 \
|
||||
--lora_target_modules=all-linear"
|
||||
|
||||
# evaluation:
|
||||
|
||||
To evaluate, first install the lmms-eval framework: pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git
|
||||
then run:
|
||||
accelerate launch --num_processes=8 -m lmms_eval \
|
||||
--model llava_hf \
|
||||
--model_args pretrained=llava-hf/llava-1.5-7b-hf \
|
||||
--tasks mmbench \
|
||||
--batch_size 1 \
|
||||
--output_path ./logs/ \
|
||||
--log_sample
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
|
||||
TRL_USE_RICH = os.environ.get("TRL_USE_RICH", False)
|
||||
|
||||
from trl.commands.cli_utils import init_zero_verbose, SftScriptArguments, TrlParser
|
||||
|
||||
if TRL_USE_RICH:
|
||||
init_zero_verbose()
|
||||
FORMAT = "%(message)s"
|
||||
|
||||
from rich.console import Console
|
||||
from rich.logging import RichHandler
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from datasets import load_dataset
|
||||
|
||||
from tqdm.rich import tqdm
|
||||
from transformers import AutoTokenizer, AutoProcessor, TrainingArguments, LlavaForConditionalGeneration
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
RichProgressCallback,
|
||||
SFTTrainer,
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
get_kbit_device_map,
|
||||
)
|
||||
|
||||
tqdm.pandas()
|
||||
|
||||
if TRL_USE_RICH:
|
||||
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = TrlParser((SftScriptArguments, TrainingArguments, ModelConfig))
|
||||
args, training_args, model_config = parser.parse_args_and_config()
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
# Force use our print callback
|
||||
if TRL_USE_RICH:
|
||||
training_args.disable_tqdm = True
|
||||
console = Console()
|
||||
|
||||
################
|
||||
# Model, Tokenizer & Processor
|
||||
################
|
||||
LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
|
||||
|
||||
torch_dtype = (
|
||||
model_config.torch_dtype
|
||||
if model_config.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_config.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_config)
|
||||
model_kwargs = dict(
|
||||
revision=model_config.model_revision,
|
||||
trust_remote_code=model_config.trust_remote_code,
|
||||
attn_implementation=model_config.attn_implementation,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
|
||||
tokenizer.chat_template = LLAVA_CHAT_TEMPLATE
|
||||
processor = AutoProcessor.from_pretrained(model_config.model_name_or_path)
|
||||
processor.tokenizer = tokenizer
|
||||
|
||||
model = LlavaForConditionalGeneration.from_pretrained(model_config.model_name_or_path, **model_kwargs)
|
||||
|
||||
################
|
||||
# Create a data collator to encode text and image pairs
|
||||
################
|
||||
|
||||
class LLavaDataCollator:
|
||||
def __init__(self, processor):
|
||||
self.processor = processor
|
||||
|
||||
def __call__(self, examples):
|
||||
texts = []
|
||||
images = []
|
||||
for example in examples:
|
||||
if len(example["images"]) > 1:
|
||||
raise ValueError("This collator only supports one image per example")
|
||||
messages = example["messages"]
|
||||
text = self.processor.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
texts.append(text)
|
||||
images.append(example["images"][0])
|
||||
|
||||
batch = self.processor(texts, images, return_tensors="pt", padding=True)
|
||||
|
||||
labels = batch["input_ids"].clone()
|
||||
if self.processor.tokenizer.pad_token_id is not None:
|
||||
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
|
||||
data_collator = LLavaDataCollator(processor)
|
||||
|
||||
################
|
||||
# Dataset
|
||||
################
|
||||
raw_datasets = load_dataset(args.dataset_name)
|
||||
train_dataset = raw_datasets["train"]
|
||||
eval_dataset = raw_datasets["test"]
|
||||
|
||||
################
|
||||
# Optional rich context managers
|
||||
###############
|
||||
init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
|
||||
save_context = (
|
||||
nullcontext()
|
||||
if not TRL_USE_RICH
|
||||
else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
|
||||
)
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
with init_context:
|
||||
trainer = SFTTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
dataset_text_field="text", # need a dummy field
|
||||
tokenizer=tokenizer,
|
||||
peft_config=get_peft_config(model_config),
|
||||
callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
|
||||
data_collator=data_collator,
|
||||
dataset_kwargs={"skip_prepare_dataset": True},
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
with save_context:
|
||||
trainer.save_model(training_args.output_dir)
|
||||
trainer.push_to_hub()
|
||||
if Accelerator().is_main_process:
|
||||
processor.push_to_hub(training_args.hub_model_id)
|
75
setup.py
75
setup.py
@ -53,11 +53,12 @@ To create the package for pypi.
|
||||
8. Change the version in __init__.py and setup.py to X.X.X+1.dev0 (e.g. VERSION=1.18.3 -> 1.18.4.dev0).
|
||||
Then push the change with a message 'set dev version'
|
||||
"""
|
||||
import os
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
|
||||
__version__ = "0.7.11" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
__version__ = "0.8.4" # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
|
||||
REQUIRED_PKGS = [
|
||||
"torch>=1.4.0",
|
||||
@ -79,34 +80,44 @@ EXTRAS["dev"] = []
|
||||
for reqs in EXTRAS.values():
|
||||
EXTRAS["dev"].extend(reqs)
|
||||
|
||||
setup(
|
||||
name="trl",
|
||||
license="Apache 2.0",
|
||||
classifiers=[
|
||||
"Development Status :: 2 - Pre-Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Natural Language :: English",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
],
|
||||
url="https://github.com/huggingface/trl",
|
||||
packages=find_packages(),
|
||||
include_package_data=True,
|
||||
install_requires=REQUIRED_PKGS,
|
||||
extras_require=EXTRAS,
|
||||
python_requires=">=3.7",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
zip_safe=False,
|
||||
version=__version__,
|
||||
description="Train transformer language models with reinforcement learning.",
|
||||
keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf",
|
||||
author="Leandro von Werra",
|
||||
author_email="leandro.vonwerra@gmail.com",
|
||||
)
|
||||
try:
|
||||
file_path = os.path.dirname(os.path.abspath(__file__))
|
||||
os.symlink(os.path.join(file_path, "examples/scripts"), os.path.join(file_path, "trl/commands/scripts"))
|
||||
|
||||
setup(
|
||||
name="trl",
|
||||
license="Apache 2.0",
|
||||
classifiers=[
|
||||
"Development Status :: 2 - Pre-Alpha",
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Natural Language :: English",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.7",
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
],
|
||||
url="https://github.com/huggingface/trl",
|
||||
entry_points={
|
||||
"console_scripts": ["trl=trl.commands.cli:main"],
|
||||
},
|
||||
include_package_data=True,
|
||||
package_data={"trl": ["commands/scripts/config/*", "commands/scripts/*"]},
|
||||
packages=find_packages(),
|
||||
install_requires=REQUIRED_PKGS,
|
||||
extras_require=EXTRAS,
|
||||
python_requires=">=3.7",
|
||||
long_description=open("README.md", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
zip_safe=False,
|
||||
version=__version__,
|
||||
description="Train transformer language models with reinforcement learning.",
|
||||
keywords="ppo, transformers, huggingface, gpt2, language modeling, rlhf",
|
||||
author="Leandro von Werra",
|
||||
author_email="leandro.vonwerra@gmail.com",
|
||||
)
|
||||
finally:
|
||||
os.unlink(os.path.join(file_path, "trl/commands/scripts"))
|
||||
|
40
tests/test_cli.py
Normal file
40
tests/test_cli.py
Normal file
@ -0,0 +1,40 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import subprocess
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
|
||||
@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
|
||||
def test_sft_cli():
|
||||
try:
|
||||
subprocess.run(
|
||||
"trl sft --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name imdb --learning_rate 1e-4 --lr_scheduler_type cosine --dataset_text_field text",
|
||||
shell=True,
|
||||
check=True,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise AssertionError("An error occured while running the CLI, please double check") from exc
|
||||
|
||||
|
||||
@unittest.skipIf(sys.platform.startswith("win"), "Skipping on Windows")
|
||||
def test_dpo_cli():
|
||||
try:
|
||||
subprocess.run(
|
||||
"trl dpo --max_steps 1 --output_dir tmp-sft --model_name_or_path HuggingFaceM4/tiny-random-LlamaForCausalLM --dataset_name trl-internal-testing/hh-rlhf-trl-style --learning_rate 1e-4 --lr_scheduler_type cosine --sanity_check",
|
||||
shell=True,
|
||||
check=True,
|
||||
)
|
||||
except BaseException as exc:
|
||||
raise AssertionError("An error occured while running the CLI, please double check") from exc
|
182
tests/test_cpo_trainer.py
Normal file
182
tests/test_cpo_trainer.py
Normal file
@ -0,0 +1,182 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
from trl import CPOConfig, CPOTrainer
|
||||
|
||||
from .testing_utils import require_peft
|
||||
|
||||
|
||||
class CPOTrainerTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
|
||||
cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
|
||||
cls.tokenizer.pad_token = cls.tokenizer.eos_token
|
||||
|
||||
# get t5 as seq2seq example:
|
||||
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
|
||||
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
def _init_dummy_dataset(self):
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"[INST] How is the stock price? [/INST]",
|
||||
"[INST] How is the stock price? [/INST] ",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Python",
|
||||
"$46 as of 10am EST",
|
||||
"46 as of 10am EST",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"Java",
|
||||
" $46 as of 10am EST",
|
||||
" 46 as of 10am EST",
|
||||
],
|
||||
}
|
||||
# fmt: on
|
||||
return Dataset.from_dict(dummy_dataset_dict)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
["gpt2", "sigmoid"],
|
||||
["t5", "hinge"],
|
||||
["gpt2", "ipo"],
|
||||
["t5", "ipo"],
|
||||
]
|
||||
)
|
||||
def test_cpo_trainer(self, name, loss_type):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = CPOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
loss_type=loss_type,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
if name == "gpt2":
|
||||
model = self.model
|
||||
tokenizer = self.tokenizer
|
||||
elif name == "t5":
|
||||
model = self.t5_model
|
||||
tokenizer = self.t5_tokenizer
|
||||
training_args.is_encoder_decoder = True
|
||||
|
||||
trainer = CPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
assert not torch.equal(param, new_param)
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
def test_cpo_trainer_with_lora(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = CPOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = CPOTrainer(
|
||||
model=self.model,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
if "lora" in n:
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
assert not torch.equal(param, new_param)
|
@ -587,3 +587,64 @@ class DPOTrainerTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
assert trainer.model.model_tags == trainer._tag_names
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
def test_dpo_lora_force_use_ref(self):
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id)
|
||||
model_peft = get_peft_model(model, lora_config)
|
||||
|
||||
ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# passing a peft_model as model and ref_model should error out,
|
||||
# unless you pass `force_use_ref_model`
|
||||
trainer = DPOTrainer(
|
||||
model=model_peft,
|
||||
ref_model=ref_model,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
trainer = DPOTrainer(
|
||||
model=model_peft,
|
||||
ref_model=ref_model,
|
||||
beta=0.1,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
force_use_ref_model=True,
|
||||
)
|
||||
|
||||
# train the model
|
||||
trainer.train()
|
||||
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
@ -31,15 +32,27 @@ class IterativeTrainerTester(unittest.TestCase):
|
||||
cls.tokenizer.pad_token = cls.tokenizer.eos_token
|
||||
|
||||
# get t5 as seq2seq example:
|
||||
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
|
||||
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab-calibrated"
|
||||
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
def _init_tensor_dummy_dataset(self):
|
||||
dummy_dataset_dict = {
|
||||
"input_ids": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])],
|
||||
"attention_mask": [torch.tensor([1, 1]), torch.tensor([1, 1, 1]), torch.tensor([1, 1])],
|
||||
"labels": [torch.tensor([5303, 3621]), torch.tensor([3666, 1438, 318]), torch.tensor([5303, 3621])],
|
||||
"input_ids": [
|
||||
torch.tensor([5303, 3621, 3666, 1438, 318]),
|
||||
torch.tensor([3666, 1438, 318, 3666, 1438, 318]),
|
||||
torch.tensor([5303, 3621, 3666, 1438, 318]),
|
||||
],
|
||||
"attention_mask": [
|
||||
torch.tensor([1, 1, 1, 1, 1]),
|
||||
torch.tensor([1, 1, 1, 1, 1, 1]),
|
||||
torch.tensor([1, 1, 1, 1, 1]),
|
||||
],
|
||||
"labels": [
|
||||
torch.tensor([5303, 3621, 3666, 1438, 318]),
|
||||
torch.tensor([3666, 1438, 318, 3666, 1438, 318]),
|
||||
torch.tensor([5303, 3621, 3666, 1438, 318]),
|
||||
],
|
||||
}
|
||||
|
||||
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
|
||||
@ -94,11 +107,10 @@ class IterativeTrainerTester(unittest.TestCase):
|
||||
tokenizer = self.t5_tokenizer
|
||||
|
||||
args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=2,
|
||||
output_dir=tmp_dir, per_device_train_batch_size=2, max_steps=2, learning_rate=1e-3
|
||||
)
|
||||
iterative_trainer = IterativeSFTTrainer(model=model, args=args, tokenizer=tokenizer)
|
||||
iterative_trainer.optimizer.zero_grad = partial(iterative_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
|
||||
iterative_trainer.step(**inputs)
|
||||
|
||||
|
387
tests/test_kto_trainer.py
Normal file
387
tests/test_kto_trainer.py
Normal file
@ -0,0 +1,387 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
from trl import KTOConfig, KTOTrainer
|
||||
from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
|
||||
|
||||
from .testing_utils import require_no_wandb, require_peft
|
||||
|
||||
|
||||
class KTOTrainerTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
|
||||
cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id)
|
||||
cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model_id)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
|
||||
cls.tokenizer.pad_token = cls.tokenizer.eos_token
|
||||
|
||||
# get t5 as seq2seq example:
|
||||
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
|
||||
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
cls.t5_ref_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
def _init_dummy_dataset(self):
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
"prompt": [
|
||||
"Hey, hello",
|
||||
"How are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
],
|
||||
"completion": [
|
||||
"hi nice to meet you",
|
||||
"leave me alone",
|
||||
"I don't have a name",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"C++",
|
||||
"Java",
|
||||
],
|
||||
"label": [
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
True,
|
||||
True,
|
||||
False,
|
||||
False,
|
||||
],
|
||||
}
|
||||
# fmt: on
|
||||
return Dataset.from_dict(dummy_dataset_dict)
|
||||
|
||||
@parameterized.expand(
|
||||
[
|
||||
["gpt2", True, True],
|
||||
["gpt2", True, False],
|
||||
# ["t5", True],
|
||||
["gpt2", False, True],
|
||||
["gpt2", False, False],
|
||||
# ["t5", False],
|
||||
]
|
||||
)
|
||||
def test_kto_trainer(self, name, pre_compute, eval_dataset):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
precompute_ref_log_probs=pre_compute,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
if name == "gpt2":
|
||||
model = self.model
|
||||
ref_model = self.ref_model
|
||||
tokenizer = self.tokenizer
|
||||
elif name == "t5":
|
||||
model = self.t5_model
|
||||
ref_model = self.t5_ref_model
|
||||
tokenizer = self.t5_tokenizer
|
||||
|
||||
trainer = KTOTrainer(
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset if eval_dataset else None,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
|
||||
def test_tokenize_and_process_tokens(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = KTOTrainer(
|
||||
model=self.model,
|
||||
ref_model=self.ref_model,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tokenized_dataset = dummy_dataset.map(
|
||||
_tokenize,
|
||||
fn_kwargs={"tokenizer": trainer.tokenizer},
|
||||
batched=True,
|
||||
batch_size=2,
|
||||
)
|
||||
self.assertListEqual(tokenized_dataset["prompt"], dummy_dataset["prompt"])
|
||||
self.assertListEqual(tokenized_dataset["completion"], dummy_dataset["completion"])
|
||||
self.assertListEqual(tokenized_dataset["label"], dummy_dataset["label"])
|
||||
self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [10814, 11])
|
||||
self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1])
|
||||
self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [5968, 1219, 72, 3621, 284, 1826, 345])
|
||||
self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1])
|
||||
|
||||
# Test reversal of (prompt, completion) pairs for KL dataset
|
||||
tokenized_kl_dataset = tokenized_dataset.map(_get_kl_dataset, batched=True, batch_size=2)
|
||||
self.assertListEqual(
|
||||
tokenized_kl_dataset["prompt_input_ids"][0], tokenized_dataset["prompt_input_ids"][0]
|
||||
)
|
||||
self.assertListEqual(
|
||||
tokenized_kl_dataset["prompt_attention_mask"][0], tokenized_dataset["prompt_attention_mask"][0]
|
||||
)
|
||||
self.assertListEqual(
|
||||
tokenized_kl_dataset["answer_input_ids"][0], tokenized_dataset["answer_input_ids"][1]
|
||||
)
|
||||
self.assertListEqual(
|
||||
tokenized_kl_dataset["answer_attention_mask"][0], tokenized_dataset["answer_attention_mask"][1]
|
||||
)
|
||||
|
||||
fn_kwargs = {
|
||||
"prefix": "",
|
||||
"is_encoder_decoder": trainer.is_encoder_decoder,
|
||||
"tokenizer": trainer.tokenizer,
|
||||
"max_length": trainer.max_length,
|
||||
"truncation_mode": trainer.truncation_mode,
|
||||
"label_pad_token_id": trainer.label_pad_token_id,
|
||||
"max_prompt_length": trainer.max_prompt_length,
|
||||
}
|
||||
processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2)
|
||||
self.assertListEqual(processed_dataset["prompt"], dummy_dataset["prompt"])
|
||||
self.assertListEqual(processed_dataset["completion"], dummy_dataset["completion"])
|
||||
self.assertListEqual(processed_dataset["label"], dummy_dataset["label"])
|
||||
self.assertListEqual(processed_dataset["prompt_input_ids"][0], [50256, 10814, 11])
|
||||
self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1])
|
||||
self.assertListEqual(
|
||||
processed_dataset["completion_input_ids"][0],
|
||||
[50256, 10814, 11, 5968, 1219, 72, 3621, 284, 1826, 345, 50256],
|
||||
)
|
||||
self.assertListEqual(
|
||||
processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
|
||||
)
|
||||
self.assertListEqual(
|
||||
processed_dataset["completion_labels"][0],
|
||||
[-100, -100, -100, 5968, 1219, 72, 3621, 284, 1826, 345, 50256],
|
||||
)
|
||||
|
||||
def test_kto_trainer_without_providing_ref_model(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = KTOTrainer(
|
||||
model=self.model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
def test_kto_trainer_without_providing_ref_model_with_lora(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = KTOTrainer(
|
||||
model=self.model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
if "lora" in n:
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
self.assertFalse(torch.equal(param, new_param))
|
||||
|
||||
@require_no_wandb
|
||||
def test_kto_trainer_generate_during_eval_no_wandb(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
generate_during_eval=True,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
ValueError,
|
||||
expected_regex="`generate_during_eval=True` requires Weights and Biases to be installed."
|
||||
" Please install with `pip install wandb` to resolve.",
|
||||
):
|
||||
KTOTrainer(
|
||||
model=self.model,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
def test_kto_lora_save(self):
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id)
|
||||
model_peft = get_peft_model(model, lora_config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = KTOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
# kto train lora model with a lora config
|
||||
trainer = KTOTrainer(
|
||||
model=model_peft,
|
||||
ref_model=None,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
# train the model
|
||||
trainer.train()
|
||||
|
||||
# save peft adapter
|
||||
trainer.save_model()
|
||||
|
||||
# assert that the model is loaded without giving OSError
|
||||
try:
|
||||
AutoModelForCausalLM.from_pretrained(tmp_dir)
|
||||
except OSError:
|
||||
self.fail("Loading the saved peft adapter failed")
|
@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
import sys
|
||||
import unittest
|
||||
from functools import partial
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -134,6 +135,7 @@ class TestPeftDependancy(unittest.TestCase):
|
||||
tokenizer=tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
|
174
tests/test_orpo_trainer.py
Normal file
174
tests/test_orpo_trainer.py
Normal file
@ -0,0 +1,174 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from parameterized import parameterized
|
||||
from pytest import mark
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
from trl import ORPOConfig, ORPOTrainer
|
||||
|
||||
from .testing_utils import require_peft
|
||||
|
||||
|
||||
class ORPOTrainerTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
|
||||
cls.model = AutoModelForCausalLM.from_pretrained(cls.model_id)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id)
|
||||
cls.tokenizer.pad_token = cls.tokenizer.eos_token
|
||||
|
||||
# get t5 as seq2seq example:
|
||||
model_id = "trl-internal-testing/tiny-T5ForConditionalGeneration-correct-vocab"
|
||||
cls.t5_model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
|
||||
cls.t5_tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
def _init_dummy_dataset(self):
|
||||
# fmt: off
|
||||
dummy_dataset_dict = {
|
||||
"prompt": [
|
||||
"hello",
|
||||
"how are you",
|
||||
"What is your name?",
|
||||
"What is your name?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"Which is the best programming language?",
|
||||
"[INST] How is the stock price? [/INST]",
|
||||
"[INST] How is the stock price? [/INST] ",
|
||||
],
|
||||
"chosen": [
|
||||
"hi nice to meet you",
|
||||
"I am fine",
|
||||
"My name is Mary",
|
||||
"My name is Mary",
|
||||
"Python",
|
||||
"Python",
|
||||
"Python",
|
||||
"$46 as of 10am EST",
|
||||
"46 as of 10am EST",
|
||||
],
|
||||
"rejected": [
|
||||
"leave me alone",
|
||||
"I am not fine",
|
||||
"Whats it to you?",
|
||||
"I dont have a name",
|
||||
"Javascript",
|
||||
"C++",
|
||||
"Java",
|
||||
" $46 as of 10am EST",
|
||||
" 46 as of 10am EST",
|
||||
],
|
||||
}
|
||||
# fmt: on
|
||||
return Dataset.from_dict(dummy_dataset_dict)
|
||||
|
||||
@parameterized.expand([["gpt2"], ["t5"]])
|
||||
def test_orpo_trainer(self, name):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = ORPOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=1,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
if name == "gpt2":
|
||||
model = self.model
|
||||
tokenizer = self.tokenizer
|
||||
elif name == "t5":
|
||||
model = self.t5_model
|
||||
tokenizer = self.t5_tokenizer
|
||||
training_args.is_encoder_decoder = True
|
||||
|
||||
trainer = ORPOTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
tokenizer=tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
assert not torch.equal(param, new_param)
|
||||
|
||||
@require_peft
|
||||
@mark.peft_test
|
||||
def test_orpo_trainer_with_lora(self):
|
||||
from peft import LoraConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = ORPOConfig(
|
||||
output_dir=tmp_dir,
|
||||
per_device_train_batch_size=2,
|
||||
max_steps=3,
|
||||
remove_unused_columns=False,
|
||||
gradient_accumulation_steps=4,
|
||||
learning_rate=9e-1,
|
||||
evaluation_strategy="steps",
|
||||
beta=0.1,
|
||||
)
|
||||
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
trainer = ORPOTrainer(
|
||||
model=self.model,
|
||||
args=training_args,
|
||||
tokenizer=self.tokenizer,
|
||||
train_dataset=dummy_dataset,
|
||||
eval_dataset=dummy_dataset,
|
||||
peft_config=lora_config,
|
||||
)
|
||||
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
trainer.train()
|
||||
|
||||
assert trainer.state.log_history[-1]["train_loss"] is not None
|
||||
|
||||
# check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
if "lora" in n:
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
# check the params have changed - ignore 0 biases
|
||||
if param.sum() != 0:
|
||||
assert not torch.equal(param, new_param)
|
@ -17,6 +17,7 @@ import gc
|
||||
import re
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@ -193,6 +194,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -220,6 +222,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -252,6 +255,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
assert isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)
|
||||
@ -291,6 +295,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
dataset=dummy_dataset,
|
||||
lr_scheduler=lr_scheduler,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
assert isinstance(ppo_trainer.optimizer.optimizer, torch.optim.SGD)
|
||||
@ -332,6 +337,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -382,6 +388,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
dataset=dummy_dataset,
|
||||
num_shared_layers=num_shared_layers,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -449,6 +456,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -486,6 +494,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
@ -526,6 +535,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
ref_model=self.gpt2_model_ref,
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
)
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
# train model with ppo
|
||||
reward = [torch.tensor([1.0])]
|
||||
# train model - this should work fine
|
||||
@ -692,7 +702,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
|
||||
# train model with ppo
|
||||
@ -879,7 +889,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
assert ppo_trainer.ref_model is None
|
||||
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
@ -967,7 +977,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
tokenizer=self.gpt2_tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
assert ppo_trainer.ref_model is None
|
||||
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
@ -1132,6 +1142,55 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
|
||||
assert generations_single == generations_batched
|
||||
|
||||
def test_generation_with_ref_model(self):
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")
|
||||
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||
|
||||
# Negate the weights in the last layer of the ref model so it never
|
||||
# outputs the same things as the primary model
|
||||
ref_model = copy.deepcopy(model)
|
||||
lm_head_weight = ref_model.pretrained_model.lm_head.weight
|
||||
lm_head_weight.data = -lm_head_weight.data
|
||||
|
||||
ppo_trainer = PPOTrainer(
|
||||
config=self.ppo_config,
|
||||
model=model,
|
||||
ref_model=ref_model,
|
||||
tokenizer=tokenizer,
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
input_texts = ["this is a test", "this is another, longer test"]
|
||||
|
||||
generation_kwargs = {"do_sample": False, "max_new_tokens": 4, "pad_token_id": tokenizer.eos_token_id}
|
||||
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model_inputs = [tokenizer(txt, return_tensors="pt").input_ids.squeeze() for txt in input_texts]
|
||||
|
||||
generations_batched, ref_generations_batched = ppo_trainer.generate(
|
||||
model_inputs, batch_size=2, generate_ref_response=True, **generation_kwargs
|
||||
)
|
||||
generations_batched = tokenizer.batch_decode(generations_batched)
|
||||
ref_generations_batched = tokenizer.batch_decode(ref_generations_batched)
|
||||
|
||||
generations_single = []
|
||||
ref_generations_single = []
|
||||
for inputs in model_inputs:
|
||||
generation, ref_generation = ppo_trainer.generate(inputs, generate_ref_response=True, **generation_kwargs)
|
||||
generations_single.append(generation.squeeze())
|
||||
ref_generations_single.append(ref_generation.squeeze())
|
||||
|
||||
generations_single = tokenizer.batch_decode(generations_single)
|
||||
ref_generations_single = tokenizer.batch_decode(ref_generations_single)
|
||||
|
||||
assert generations_single == generations_batched
|
||||
assert ref_generations_single == ref_generations_batched
|
||||
|
||||
assert generations_batched != ref_generations_batched
|
||||
assert generations_single != ref_generations_single
|
||||
|
||||
def test_grad_accumulation(self):
|
||||
dummy_dataset = self._init_dummy_dataset()
|
||||
|
||||
@ -1213,6 +1272,7 @@ class PPOTrainerTester(unittest.TestCase):
|
||||
dataset=dummy_dataset,
|
||||
)
|
||||
|
||||
ppo_trainer.optimizer.zero_grad = partial(ppo_trainer.optimizer.zero_grad, set_to_none=False)
|
||||
dummy_dataloader = ppo_trainer.dataloader
|
||||
# train model with ppo
|
||||
for query_tensor, response_tensor in dummy_dataloader:
|
||||
|
@ -106,11 +106,8 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
|
||||
@require_peft
|
||||
def test_reward_trainer_peft(self):
|
||||
import peft
|
||||
from peft import LoraConfig, TaskType
|
||||
|
||||
peft_version = peft.__version__
|
||||
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.SEQ_CLS,
|
||||
inference_mode=False,
|
||||
@ -172,7 +169,7 @@ class RewardTrainerTester(unittest.TestCase):
|
||||
previous_non_trainable_params = {}
|
||||
|
||||
# due to a change in the way the modules to save are dealt in PEFT.
|
||||
trainable_params_name = ["lora", "score"] if peft_version < "0.3.0" else ["lora", "modules_to_save"]
|
||||
trainable_params_name = ["lora", "modules_to_save"]
|
||||
|
||||
# check gradients are not None
|
||||
for n, param in trainer.model.named_parameters():
|
||||
|
53
tests/test_rich_progress_callback.py
Normal file
53
tests/test_rich_progress_callback.py
Normal file
@ -0,0 +1,53 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from datasets import Dataset
|
||||
from transformers import Trainer, TrainingArguments
|
||||
|
||||
from trl.trainer.utils import RichProgressCallback
|
||||
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.a = nn.Parameter(torch.tensor(1.0))
|
||||
|
||||
def forward(self, x):
|
||||
return self.a * x
|
||||
|
||||
|
||||
class TestRichProgressCallback(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.dummy_model = DummyModel()
|
||||
cls.dummy_train_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 5)
|
||||
cls.dummy_val_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 101)
|
||||
|
||||
def test_rich_progress_callback_logging(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
per_device_eval_batch_size=2,
|
||||
per_device_train_batch_size=2,
|
||||
num_train_epochs=4,
|
||||
evaluation_strategy="steps",
|
||||
eval_steps=1,
|
||||
logging_strategy="steps",
|
||||
logging_steps=1,
|
||||
save_strategy="no",
|
||||
report_to="none",
|
||||
disable_tqdm=True,
|
||||
)
|
||||
callbacks = [RichProgressCallback()]
|
||||
trainer = Trainer(
|
||||
model=self.dummy_model,
|
||||
train_dataset=self.dummy_train_dataset,
|
||||
eval_dataset=self.dummy_val_dataset,
|
||||
args=training_args,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.train()
|
@ -19,14 +19,20 @@ import unittest
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
|
||||
from datasets import Dataset, Image, Sequence
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoProcessor,
|
||||
AutoTokenizer,
|
||||
LlavaForConditionalGeneration,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
from trl import SFTTrainer
|
||||
from trl.import_utils import is_peft_available
|
||||
from trl.import_utils import is_peft_available, is_pil_available
|
||||
from trl.trainer import ConstantLengthDataset, DataCollatorForCompletionOnlyLM
|
||||
|
||||
from .testing_utils import require_peft
|
||||
from .testing_utils import require_peft, requires_pil
|
||||
|
||||
|
||||
def formatting_prompts_func(example):
|
||||
@ -45,6 +51,9 @@ def formatting_prompts_func_batched(example):
|
||||
if is_peft_available():
|
||||
from peft import LoraConfig, PeftModel
|
||||
|
||||
if is_pil_available():
|
||||
from PIL import Image as PILImage
|
||||
|
||||
|
||||
class SFTTrainerTester(unittest.TestCase):
|
||||
r""" """
|
||||
@ -123,6 +132,49 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
]
|
||||
)
|
||||
|
||||
if is_pil_available():
|
||||
cls.dummy_vsft_instruction_dataset = Dataset.from_dict(
|
||||
{
|
||||
"messages": [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "It is random noise."}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Oh ye, you are right, what is 1+1"}],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "2"}],
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "What is in this image?"}, {"type": "image"}],
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "It is random noise."}],
|
||||
},
|
||||
],
|
||||
],
|
||||
"images": [
|
||||
[PILImage.fromarray((np.random.rand(40, 50, 3) * 255).astype("uint8")).convert("RGBA")],
|
||||
[PILImage.fromarray((np.random.rand(50, 60, 3) * 255).astype("uint8")).convert("RGBA")],
|
||||
],
|
||||
}
|
||||
)
|
||||
cls.dummy_vsft_instruction_dataset = cls.dummy_vsft_instruction_dataset.cast_column(
|
||||
"images", Sequence(Image())
|
||||
)
|
||||
|
||||
cls.train_dataset = ConstantLengthDataset(
|
||||
cls.tokenizer,
|
||||
cls.dummy_dataset,
|
||||
@ -930,3 +982,145 @@ class SFTTrainerTester(unittest.TestCase):
|
||||
)
|
||||
|
||||
assert trainer.model.model_tags == trainer._tag_names
|
||||
|
||||
def test_sft_trainer_eval_packing(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
dataloader_drop_last=True,
|
||||
evaluation_strategy="steps",
|
||||
max_steps=4,
|
||||
eval_steps=2,
|
||||
save_steps=2,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_checkpointing=True,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=self.model_id,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_chatml_dataset,
|
||||
eval_dataset=self.dummy_chatml_dataset,
|
||||
packing=True,
|
||||
max_seq_length=32, # make sure there is at least 1 packed sequence
|
||||
eval_packing=False,
|
||||
)
|
||||
|
||||
assert len(trainer.train_dataset["input_ids"]) == 1
|
||||
assert len(trainer.eval_dataset["input_ids"]) != 1
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=self.model_id,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_chatml_dataset,
|
||||
eval_dataset=self.dummy_chatml_dataset,
|
||||
max_seq_length=32, # make sure there is at least 1 packed sequence
|
||||
packing=True,
|
||||
)
|
||||
|
||||
assert len(trainer.train_dataset["input_ids"]) == 1
|
||||
assert len(trainer.eval_dataset["input_ids"]) == 1
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=self.model_id,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_chatml_dataset,
|
||||
eval_dataset=self.dummy_chatml_dataset,
|
||||
max_seq_length=32, # make sure there is at least 1 packed sequence
|
||||
packing=False,
|
||||
)
|
||||
|
||||
assert len(trainer.train_dataset["input_ids"]) != 1
|
||||
assert len(trainer.eval_dataset["input_ids"]) != 1
|
||||
|
||||
@requires_pil
|
||||
def test_sft_trainer_skip_prepare_dataset(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
dataloader_drop_last=True,
|
||||
evaluation_strategy="steps",
|
||||
max_steps=4,
|
||||
eval_steps=2,
|
||||
save_steps=2,
|
||||
per_device_train_batch_size=2,
|
||||
gradient_checkpointing=True,
|
||||
remove_unused_columns=False,
|
||||
)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=self.model_id,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_vsft_instruction_dataset,
|
||||
eval_dataset=self.dummy_vsft_instruction_dataset,
|
||||
dataset_text_field="text", # need a dummy field
|
||||
dataset_kwargs={"skip_prepare_dataset": True},
|
||||
)
|
||||
assert trainer.train_dataset.features == self.dummy_vsft_instruction_dataset.features
|
||||
assert trainer.eval_dataset.features == self.dummy_vsft_instruction_dataset.features
|
||||
|
||||
@requires_pil
|
||||
def test_sft_trainer_llava(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
training_args = TrainingArguments(
|
||||
output_dir=tmp_dir,
|
||||
dataloader_drop_last=True,
|
||||
evaluation_strategy="steps",
|
||||
max_steps=4,
|
||||
eval_steps=2,
|
||||
save_steps=2,
|
||||
per_device_train_batch_size=2,
|
||||
per_device_eval_batch_size=2,
|
||||
remove_unused_columns=False,
|
||||
)
|
||||
tiny_llava = LlavaForConditionalGeneration.from_pretrained(
|
||||
"trl-internal-testing/tiny-random-LlavaForConditionalGeneration"
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-random-LlavaForConditionalGeneration")
|
||||
|
||||
processor.tokenizer.chat_template = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""
|
||||
|
||||
class LLavaDataCollator:
|
||||
def __init__(self, processor):
|
||||
self.processor = processor
|
||||
|
||||
def __call__(self, examples):
|
||||
texts = []
|
||||
images = []
|
||||
for example in examples:
|
||||
if len(example["images"]) > 1:
|
||||
raise ValueError("This collator only supports one image per example")
|
||||
messages = example["messages"]
|
||||
text = self.processor.tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=False
|
||||
)
|
||||
texts.append(text)
|
||||
images.append(example["images"][0])
|
||||
|
||||
batch = self.processor(texts, images, return_tensors="pt", padding=True)
|
||||
|
||||
labels = batch["input_ids"].clone()
|
||||
if self.processor.tokenizer.pad_token_id is not None:
|
||||
labels[labels == self.processor.tokenizer.pad_token_id] = -100
|
||||
batch["labels"] = labels
|
||||
|
||||
return batch
|
||||
|
||||
data_collator = LLavaDataCollator(processor)
|
||||
|
||||
trainer = SFTTrainer(
|
||||
model=tiny_llava,
|
||||
args=training_args,
|
||||
train_dataset=self.dummy_vsft_instruction_dataset,
|
||||
eval_dataset=self.dummy_vsft_instruction_dataset,
|
||||
dataset_text_field="text", # need a dummy field
|
||||
dataset_kwargs={"skip_prepare_dataset": True},
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
|
||||
assert trainer.state.log_history[(-1)]["train_loss"] is not None
|
||||
assert trainer.state.log_history[0]["eval_loss"] is not None
|
||||
|
||||
assert "model.safetensors" in os.listdir(tmp_dir + "/checkpoint-2")
|
||||
|
@ -19,6 +19,7 @@ from trl import (
|
||||
is_bitsandbytes_available,
|
||||
is_diffusers_available,
|
||||
is_peft_available,
|
||||
is_pil_available,
|
||||
is_wandb_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
@ -51,6 +52,15 @@ def require_diffusers(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def requires_pil(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires PIL. Skips the test if pil is not available.
|
||||
"""
|
||||
if not is_pil_available():
|
||||
test_case = unittest.skip("test requires PIL")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
def require_wandb(test_case, required: bool = True):
|
||||
"""
|
||||
Decorator marking a test that requires wandb. Skips the test if wandb is not available.
|
||||
|
175
trl/__init__.py
175
trl/__init__.py
@ -1,44 +1,143 @@
|
||||
# flake8: noqa
|
||||
|
||||
__version__ = "0.7.11"
|
||||
__version__ = "0.8.4"
|
||||
|
||||
from .core import set_seed
|
||||
from .environment import TextEnvironment, TextHistory
|
||||
from .extras import BestOfNSampler
|
||||
from .import_utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_diffusers_available,
|
||||
is_npu_available,
|
||||
is_peft_available,
|
||||
is_wandb_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
from .models import (
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
PreTrainedModelWrapper,
|
||||
create_reference_model,
|
||||
setup_chat_format,
|
||||
)
|
||||
from .trainer import (
|
||||
DataCollatorForCompletionOnlyLM,
|
||||
DPOTrainer,
|
||||
IterativeSFTTrainer,
|
||||
ModelConfig,
|
||||
PPOConfig,
|
||||
PPOTrainer,
|
||||
RewardConfig,
|
||||
RewardTrainer,
|
||||
SFTTrainer,
|
||||
)
|
||||
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
|
||||
from typing import TYPE_CHECKING
|
||||
from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
|
||||
|
||||
_import_structure = {
|
||||
"core": [
|
||||
"set_seed",
|
||||
],
|
||||
"environment": [
|
||||
"TextEnvironment",
|
||||
"TextHistory",
|
||||
],
|
||||
"extras": [
|
||||
"BestOfNSampler",
|
||||
],
|
||||
"import_utils": [
|
||||
"is_bitsandbytes_available",
|
||||
"is_diffusers_available",
|
||||
"is_npu_available",
|
||||
"is_peft_available",
|
||||
"is_pil_available",
|
||||
"is_wandb_available",
|
||||
"is_xpu_available",
|
||||
],
|
||||
"models": [
|
||||
"AutoModelForCausalLMWithValueHead",
|
||||
"AutoModelForSeq2SeqLMWithValueHead",
|
||||
"PreTrainedModelWrapper",
|
||||
"create_reference_model",
|
||||
"setup_chat_format",
|
||||
"SUPPORTED_ARCHITECTURES",
|
||||
],
|
||||
"trainer": [
|
||||
"DataCollatorForCompletionOnlyLM",
|
||||
"DPOTrainer",
|
||||
"CPOConfig",
|
||||
"CPOTrainer",
|
||||
"IterativeSFTTrainer",
|
||||
"KTOConfig",
|
||||
"KTOTrainer",
|
||||
"ModelConfig",
|
||||
"ORPOConfig",
|
||||
"ORPOTrainer",
|
||||
"PPOConfig",
|
||||
"PPOTrainer",
|
||||
"RewardConfig",
|
||||
"RewardTrainer",
|
||||
"SFTTrainer",
|
||||
],
|
||||
"commands": [],
|
||||
"commands.cli_utils": ["init_zero_verbose", "SftScriptArguments", "DpoScriptArguments", "TrlParser"],
|
||||
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "RichProgressCallback"],
|
||||
"multitask_prompt_tuning": [
|
||||
"MultitaskPromptEmbedding",
|
||||
"MultitaskPromptTuningConfig",
|
||||
"MultitaskPromptTuningInit",
|
||||
],
|
||||
}
|
||||
|
||||
if is_diffusers_available():
|
||||
from .models import (
|
||||
DDPOPipelineOutput,
|
||||
DDPOSchedulerOutput,
|
||||
DDPOStableDiffusionPipeline,
|
||||
DefaultDDPOStableDiffusionPipeline,
|
||||
try:
|
||||
if not is_diffusers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"DDPOPipelineOutput",
|
||||
"DDPOSchedulerOutput",
|
||||
"DDPOStableDiffusionPipeline",
|
||||
"DefaultDDPOStableDiffusionPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"])
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .core import set_seed
|
||||
from .environment import TextEnvironment, TextHistory
|
||||
from .extras import BestOfNSampler
|
||||
from .import_utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_diffusers_available,
|
||||
is_npu_available,
|
||||
is_peft_available,
|
||||
is_pil_available,
|
||||
is_wandb_available,
|
||||
is_xpu_available,
|
||||
)
|
||||
from .models import (
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
PreTrainedModelWrapper,
|
||||
create_reference_model,
|
||||
setup_chat_format,
|
||||
SUPPORTED_ARCHITECTURES,
|
||||
)
|
||||
from .trainer import (
|
||||
DataCollatorForCompletionOnlyLM,
|
||||
DPOTrainer,
|
||||
CPOConfig,
|
||||
CPOTrainer,
|
||||
IterativeSFTTrainer,
|
||||
KTOConfig,
|
||||
KTOTrainer,
|
||||
ModelConfig,
|
||||
ORPOConfig,
|
||||
ORPOTrainer,
|
||||
PPOConfig,
|
||||
PPOTrainer,
|
||||
RewardConfig,
|
||||
RewardTrainer,
|
||||
SFTTrainer,
|
||||
)
|
||||
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config, RichProgressCallback
|
||||
from .commands.cli_utils import init_zero_verbose, SftScriptArguments, DpoScriptArguments, TrlParser
|
||||
|
||||
try:
|
||||
if not is_diffusers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .models import (
|
||||
DDPOPipelineOutput,
|
||||
DDPOSchedulerOutput,
|
||||
DDPOStableDiffusionPipeline,
|
||||
DefaultDDPOStableDiffusionPipeline,
|
||||
)
|
||||
from .trainer import DDPOConfig, DDPOTrainer
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={"__version__": __version__},
|
||||
)
|
||||
from .trainer import DDPOConfig, DDPOTrainer
|
||||
|
34
trl/commands/__init__.py
Normal file
34
trl/commands/__init__.py
Normal file
@ -0,0 +1,34 @@
|
||||
# flake8: noqa
|
||||
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# flake8: noqa
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from ..import_utils import _LazyModule, OptionalDependencyNotAvailable
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"cli_utils": ["SftArgumentParser", "init_zero_verbose", "DpoScriptArguments", "TrlParser"],
|
||||
"config_parser": ["YamlConfigParser"],
|
||||
}
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .cli_utils import SftScriptArguments, init_zero_verbose, DpoScriptArguments, TrlParser
|
||||
from .config_parser import YamlConfigParser
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
71
trl/commands/cli.py
Normal file
71
trl/commands/cli.py
Normal file
@ -0,0 +1,71 @@
|
||||
# This file is a copy of trl/examples/scripts/sft.py so that we could
|
||||
# use it together with rich and the TRL CLI in a more customizable manner.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from subprocess import CalledProcessError
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
|
||||
SUPPORTED_COMMANDS = ["sft", "dpo", "chat"]
|
||||
|
||||
|
||||
def main():
|
||||
console = Console()
|
||||
# Make sure to import things locally to avoid verbose from third party libs.
|
||||
with console.status("[bold purple]Welcome! Initializing the TRL CLI..."):
|
||||
from trl.commands.cli_utils import init_zero_verbose
|
||||
|
||||
init_zero_verbose()
|
||||
|
||||
command_name = sys.argv[1]
|
||||
|
||||
if command_name not in SUPPORTED_COMMANDS:
|
||||
raise ValueError(
|
||||
f"Please use one of the supported commands, got {command_name} - supported commands are {SUPPORTED_COMMANDS}"
|
||||
)
|
||||
|
||||
trl_examples_dir = os.path.dirname(__file__)
|
||||
|
||||
# Force-use rich
|
||||
os.environ["TRL_USE_RICH"] = "1"
|
||||
|
||||
if command_name == "chat":
|
||||
command = f"""
|
||||
python {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])}
|
||||
"""
|
||||
else:
|
||||
command = f"""
|
||||
accelerate launch {trl_examples_dir}/scripts/{command_name}.py {" ".join(sys.argv[2:])}
|
||||
"""
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
command.split(),
|
||||
text=True,
|
||||
check=True,
|
||||
encoding="utf-8",
|
||||
cwd=os.getcwd(),
|
||||
env=os.environ.copy(),
|
||||
)
|
||||
except (CalledProcessError, ChildProcessError) as exc:
|
||||
console.log(f"TRL - {command_name.upper()} failed on ! See the logs above for further details.")
|
||||
raise ValueError("TRL CLI failed! Check the traceback above..") from exc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
288
trl/commands/cli_utils.py
Normal file
288
trl/commands/cli_utils.py
Normal file
@ -0,0 +1,288 @@
|
||||
# This file is a copy of trl/examples/scripts/sft.py so that we could
|
||||
# use it together with rich and the TRL CLI in a more customizable manner.
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, List
|
||||
|
||||
import yaml
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
class YamlConfigParser:
|
||||
def __init__(self, config_path: str = None, dataclasses: List[Any] = None):
|
||||
self.config = None
|
||||
|
||||
if config_path is not None:
|
||||
with open(config_path) as yaml_file:
|
||||
self.config = yaml.safe_load(yaml_file)
|
||||
else:
|
||||
self.config = {}
|
||||
|
||||
if dataclasses is None:
|
||||
dataclasses = []
|
||||
|
||||
# We create a dummy training args to compare the values before / after
|
||||
# __post_init__
|
||||
# Here we import `TrainingArguments` from the local level to not
|
||||
# break TRL lazy imports.
|
||||
from transformers import TrainingArguments
|
||||
|
||||
self._dummy_training_args = TrainingArguments(output_dir="dummy-training-args")
|
||||
|
||||
self.parse_and_set_env()
|
||||
self.merge_dataclasses(dataclasses)
|
||||
|
||||
def parse_and_set_env(self):
|
||||
if "env" in self.config:
|
||||
env_vars = self.config["env"]
|
||||
if isinstance(env_vars, dict):
|
||||
for key, value in env_vars.items():
|
||||
os.environ[key] = str(value)
|
||||
else:
|
||||
raise ValueError("`env` field should be a dict in the YAML file.")
|
||||
|
||||
def merge_dataclasses(self, dataclasses):
|
||||
from transformers import TrainingArguments
|
||||
|
||||
dataclasses_copy = [deepcopy(dataclass) for dataclass in dataclasses]
|
||||
|
||||
if len(self.config) > 0:
|
||||
for i, dataclass in enumerate(dataclasses):
|
||||
is_hf_training_args = False
|
||||
|
||||
for data_class_field in fields(dataclass):
|
||||
# Get the field here
|
||||
field_name = data_class_field.name
|
||||
field_value = getattr(dataclass, field_name)
|
||||
|
||||
if not isinstance(dataclass, TrainingArguments):
|
||||
default_value = data_class_field.default
|
||||
else:
|
||||
default_value = (
|
||||
getattr(self._dummy_training_args, field_name)
|
||||
if field_name != "output_dir"
|
||||
else field_name
|
||||
)
|
||||
is_hf_training_args = True
|
||||
|
||||
default_value_changed = field_value != default_value
|
||||
|
||||
if field_value is not None or field_name in self.config:
|
||||
if field_name in self.config:
|
||||
# In case the field value is not different from default, overwrite it
|
||||
if not default_value_changed:
|
||||
value_to_replace = self.config[field_name]
|
||||
setattr(dataclasses_copy[i], field_name, value_to_replace)
|
||||
# Otherwise do nothing
|
||||
|
||||
# Re-init `TrainingArguments` to handle all post-processing correctly
|
||||
if is_hf_training_args:
|
||||
init_signature = list(inspect.signature(TrainingArguments.__init__).parameters)
|
||||
dict_dataclass = asdict(dataclasses_copy[i])
|
||||
new_dict_dataclass = {k: v for k, v in dict_dataclass.items() if k in init_signature}
|
||||
dataclasses_copy[i] = TrainingArguments(**new_dict_dataclass)
|
||||
|
||||
return dataclasses_copy
|
||||
|
||||
def to_string(self):
|
||||
final_string = """"""
|
||||
for key, value in self.config.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
if len(value) != 0:
|
||||
value = str(value)
|
||||
value = value.replace("'", '"')
|
||||
value = f"'{value}'"
|
||||
else:
|
||||
continue
|
||||
|
||||
final_string += f"--{key} {value} "
|
||||
return final_string
|
||||
|
||||
|
||||
def init_zero_verbose():
|
||||
"""
|
||||
Perform zero verbose init - use this method on top of the CLI modules to make
|
||||
"""
|
||||
import logging
|
||||
import warnings
|
||||
|
||||
from rich.logging import RichHandler
|
||||
|
||||
FORMAT = "%(message)s"
|
||||
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.ERROR)
|
||||
|
||||
# Custom warning handler to redirect warnings to the logging system
|
||||
def warning_handler(message, category, filename, lineno, file=None, line=None):
|
||||
logging.warning(f"{filename}:{lineno}: {category.__name__}: {message}")
|
||||
|
||||
# Add the custom warning handler - we need to do that before importing anything to make sure the loggers work well
|
||||
warnings.showwarning = warning_handler
|
||||
|
||||
|
||||
@dataclass
|
||||
class SftScriptArguments:
|
||||
dataset_name: str = field(default="timdettmers/openassistant-guanaco", metadata={"help": "the dataset name"})
|
||||
dataset_text_field: str = field(default=None, metadata={"help": "the text field of the dataset"})
|
||||
max_seq_length: int = field(default=512, metadata={"help": "The maximum sequence length for SFT Trainer"})
|
||||
packing: bool = field(default=False, metadata={"help": "Whether to apply data packing or not during training"})
|
||||
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
|
||||
gradient_checkpointing_use_reentrant: bool = field(
|
||||
default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DpoScriptArguments:
|
||||
dataset_name: str = field(default=None, metadata={"help": "the dataset name"})
|
||||
beta: float = field(default=0.1, metadata={"help": "the beta parameter for DPO loss"})
|
||||
max_length: int = field(default=512, metadata={"help": "max length of each sample"})
|
||||
max_prompt_length: int = field(default=128, metadata={"help": "max length of each sample's prompt"})
|
||||
max_target_length: int = field(
|
||||
default=128, metadata={"help": "Only used for encoder decoder model. Max target of each sample's prompt"}
|
||||
)
|
||||
sanity_check: bool = field(default=False, metadata={"help": "only train on 1000 samples"})
|
||||
ignore_bias_buffers: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "debug argument for distributed training;"
|
||||
"fix for DDP issues with LM bias/mask buffers - invalid scalar type,`inplace operation. See"
|
||||
"https://github.com/huggingface/transformers/issues/22482#issuecomment-1595790992"
|
||||
},
|
||||
)
|
||||
generate_during_eval: bool = field(default=False, metadata={"help": "Generate during evaluation"})
|
||||
config: str = field(default=None, metadata={"help": "Path to the optional config file"})
|
||||
gradient_checkpointing_use_reentrant: bool = field(
|
||||
default=False, metadata={"help": "Whether to apply `use_reentrant` for gradient_checkpointing"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChatArguments:
|
||||
# general settings
|
||||
model_name_or_path: str = field(metadata={"help": "Name of the pre-trained model"})
|
||||
user: str = field(default=None, metadata={"help": "Username to display in chat interface"})
|
||||
system_prompt: str = field(default=None, metadata={"help": "System prompt"})
|
||||
save_folder: str = field(default="./chat_history/", metadata={"help": "Folder to save chat history"})
|
||||
device: str = field(
|
||||
default="cpu",
|
||||
metadata={"help": "device to use for inference."},
|
||||
)
|
||||
config: str = field(
|
||||
default="default",
|
||||
metadata={
|
||||
"help": "Config file used for setting the configs. If `default` uses examples/scripts/config/default_chat_config.yaml"
|
||||
},
|
||||
)
|
||||
examples: str = field(default=None, metadata={"help": "Empty placeholder needs to be set via config."})
|
||||
# generation settings
|
||||
max_new_tokens: int = field(default=256, metadata={"help": "Maximum number of tokens to generate"})
|
||||
do_sample: bool = field(default=True, metadata={"help": "Whether to sample outputs during generation"})
|
||||
num_beams: int = field(default=1, metadata={"help": "Number of beams for beam search"})
|
||||
temperature: float = field(default=1.0, metadata={"help": "Temperature parameter for generation"})
|
||||
top_k: int = field(default=50, metadata={"help": "Value of k for top-k sampling"})
|
||||
top_p: float = field(default=1.0, metadata={"help": "Value of p for nucleus sampling"})
|
||||
repetition_penalty: float = field(default=1.0, metadata={"help": "Repetition penalty"})
|
||||
# model loading
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
||||
)
|
||||
torch_dtype: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
||||
"dtype will be automatically derived from the model's weights."
|
||||
),
|
||||
"choices": ["auto", "bfloat16", "float16", "float32"],
|
||||
},
|
||||
)
|
||||
trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
|
||||
attn_implementation: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": (
|
||||
"Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`"
|
||||
)
|
||||
},
|
||||
)
|
||||
load_in_8bit: bool = field(
|
||||
default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}
|
||||
)
|
||||
load_in_4bit: bool = field(
|
||||
default=False, metadata={"help": "use 4 bit precision for the base model - works only with LoRA"}
|
||||
)
|
||||
|
||||
bnb_4bit_quant_type: str = field(default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"})
|
||||
use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
|
||||
|
||||
|
||||
class TrlParser(HfArgumentParser):
|
||||
def __init__(self, parsers):
|
||||
"""
|
||||
The TRL parser parses a list of parsers (TrainingArguments, trl.ModelConfig, etc.), creates a config
|
||||
parsers for users that pass a valid `config` field and merge the values that are set in the config
|
||||
with the processed parsers.
|
||||
|
||||
Args:
|
||||
parsers (`List[argparse.ArgumentParser`]):
|
||||
List of parsers.
|
||||
"""
|
||||
super().__init__(parsers)
|
||||
|
||||
def post_process_dataclasses(self, dataclasses):
|
||||
# Apply additional post-processing in case some arguments needs a special
|
||||
# care
|
||||
training_args = trl_args = None
|
||||
training_args_index = None
|
||||
|
||||
for i, dataclass_obj in enumerate(dataclasses):
|
||||
if dataclass_obj.__class__.__name__ == "TrainingArguments":
|
||||
training_args = dataclass_obj
|
||||
training_args_index = i
|
||||
elif dataclass_obj.__class__.__name__ in ("SftScriptArguments", "DpoScriptArguments"):
|
||||
trl_args = dataclass_obj
|
||||
else:
|
||||
...
|
||||
|
||||
if trl_args is not None and training_args is not None:
|
||||
training_args.gradient_checkpointing_kwargs = dict(
|
||||
use_reentrant=trl_args.gradient_checkpointing_use_reentrant
|
||||
)
|
||||
dataclasses[training_args_index] = training_args
|
||||
|
||||
return dataclasses
|
||||
|
||||
def parse_args_and_config(self):
|
||||
dataclasses = self.parse_args_into_dataclasses(return_remaining_strings=True)
|
||||
# Pop the last element which should be the remaining strings
|
||||
dataclasses = self.update_dataclasses_with_config(dataclasses[:-1])
|
||||
return dataclasses
|
||||
|
||||
def update_dataclasses_with_config(self, dataclasses):
|
||||
self.config_parser = None
|
||||
for parser_dataclass in dataclasses:
|
||||
if hasattr(parser_dataclass, "config"):
|
||||
if self.config_parser is not None:
|
||||
raise ValueError("You passed the `config` field twice! Make sure to pass `config` only once.")
|
||||
self.config_parser = YamlConfigParser(parser_dataclass.config)
|
||||
|
||||
if self.config_parser is not None:
|
||||
dataclasses = self.config_parser.merge_dataclasses(dataclasses)
|
||||
dataclasses = self.post_process_dataclasses(dataclasses)
|
||||
return dataclasses
|
38
trl/core.py
38
trl/core.py
@ -22,7 +22,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformers import top_k_top_p_filtering
|
||||
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
|
||||
|
||||
from .import_utils import is_npu_available, is_xpu_available
|
||||
|
||||
@ -36,6 +36,42 @@ except ImportError:
|
||||
WANDB_PADDING = -1
|
||||
|
||||
|
||||
def top_k_top_p_filtering(
|
||||
logits: torch.FloatTensor,
|
||||
top_k: int = 0,
|
||||
top_p: float = 1.0,
|
||||
filter_value: float = -float("Inf"),
|
||||
min_tokens_to_keep: int = 1,
|
||||
) -> torch.FloatTensor:
|
||||
"""
|
||||
Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.
|
||||
|
||||
Args:
|
||||
logits: logits distribution shape (batch size, vocabulary size)
|
||||
top_k (`int`, *optional*, defaults to 0):
|
||||
If > 0, only keep the top k tokens with highest probability (top-k filtering)
|
||||
top_p (`float`, *optional*, defaults to 1.0):
|
||||
If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus
|
||||
filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
|
||||
min_tokens_to_keep (`int`, *optional*, defaults to 1):
|
||||
Minimumber of tokens we keep per batch example in the output.
|
||||
|
||||
From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
|
||||
"""
|
||||
|
||||
if top_k > 0:
|
||||
logits = TopKLogitsWarper(top_k=top_k, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
|
||||
None, logits
|
||||
)
|
||||
|
||||
if 0 <= top_p <= 1.0:
|
||||
logits = TopPLogitsWarper(top_p=top_p, filter_value=filter_value, min_tokens_to_keep=min_tokens_to_keep)(
|
||||
None, logits
|
||||
)
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def flatten_dict(nested: Dict, sep: str = "/") -> Dict:
|
||||
"""Flatten dictionary and concatenate nested keys with separator."""
|
||||
|
||||
|
@ -1,3 +1,14 @@
|
||||
# flake8: noqa
|
||||
from typing import TYPE_CHECKING
|
||||
from ..import_utils import _LazyModule
|
||||
|
||||
from .base_environment import TextEnvironment, TextHistory
|
||||
_import_structure = {
|
||||
"base_environment": ["TextEnvironment", "TextHistory"],
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .base_environment import TextEnvironment, TextHistory
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
@ -13,4 +13,18 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .best_of_n_sampler import BestOfNSampler
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..import_utils import _LazyModule
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"best_of_n_sampler": ["BestOfNSampler"],
|
||||
}
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .best_of_n_sampler import BestOfNSampler
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
@ -12,7 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from importlib.util import find_spec
|
||||
from itertools import chain
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
@ -22,11 +27,11 @@ else:
|
||||
|
||||
|
||||
def is_peft_available() -> bool:
|
||||
return importlib.util.find_spec("peft") is not None
|
||||
return find_spec("peft") is not None
|
||||
|
||||
|
||||
def is_unsloth_available() -> bool:
|
||||
return importlib.util.find_spec("unsloth") is not None
|
||||
return find_spec("unsloth") is not None
|
||||
|
||||
|
||||
def is_accelerate_greater_20_0() -> bool:
|
||||
@ -41,9 +46,16 @@ def is_accelerate_greater_20_0() -> bool:
|
||||
return accelerate_version >= "0.20.0"
|
||||
|
||||
|
||||
def is_transformers_greater_than(version: str) -> bool:
|
||||
_transformers_version = importlib.metadata.version("transformers")
|
||||
return _transformers_version > version
|
||||
def is_transformers_greater_than(current_version: str) -> bool:
|
||||
if _is_python_greater_3_8:
|
||||
from importlib.metadata import version
|
||||
|
||||
_transformers_version = version("transformers")
|
||||
else:
|
||||
import pkg_resources
|
||||
|
||||
_transformers_version = pkg_resources.get_distribution("transformers").version
|
||||
return _transformers_version > current_version
|
||||
|
||||
|
||||
def is_torch_greater_2_0() -> bool:
|
||||
@ -59,26 +71,30 @@ def is_torch_greater_2_0() -> bool:
|
||||
|
||||
|
||||
def is_diffusers_available() -> bool:
|
||||
return importlib.util.find_spec("diffusers") is not None
|
||||
return find_spec("diffusers") is not None
|
||||
|
||||
|
||||
def is_pil_available() -> bool:
|
||||
return find_spec("PIL") is not None
|
||||
|
||||
|
||||
def is_bitsandbytes_available() -> bool:
|
||||
import torch
|
||||
|
||||
# bnb can be imported without GPU but is not usable.
|
||||
return importlib.util.find_spec("bitsandbytes") is not None and torch.cuda.is_available()
|
||||
return find_spec("bitsandbytes") is not None and torch.cuda.is_available()
|
||||
|
||||
|
||||
def is_torchvision_available() -> bool:
|
||||
return importlib.util.find_spec("torchvision") is not None
|
||||
return find_spec("torchvision") is not None
|
||||
|
||||
|
||||
def is_rich_available() -> bool:
|
||||
return importlib.util.find_spec("rich") is not None
|
||||
return find_spec("rich") is not None
|
||||
|
||||
|
||||
def is_wandb_available() -> bool:
|
||||
return importlib.util.find_spec("wandb") is not None
|
||||
return find_spec("wandb") is not None
|
||||
|
||||
|
||||
def is_xpu_available() -> bool:
|
||||
@ -87,7 +103,7 @@ def is_xpu_available() -> bool:
|
||||
|
||||
return accelerate.utils.is_xpu_available()
|
||||
else:
|
||||
if importlib.util.find_spec("intel_extension_for_pytorch") is None:
|
||||
if find_spec("intel_extension_for_pytorch") is None:
|
||||
return False
|
||||
try:
|
||||
import torch
|
||||
@ -99,10 +115,74 @@ def is_xpu_available() -> bool:
|
||||
|
||||
def is_npu_available() -> bool:
|
||||
"""Checks if `torch_npu` is installed and potentially if a NPU is in the environment"""
|
||||
if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
|
||||
if find_spec("torch") is None or find_spec("torch_npu") is None:
|
||||
return False
|
||||
|
||||
import torch
|
||||
import torch_npu # noqa: F401
|
||||
|
||||
return hasattr(torch, "npu") and torch.npu.is_available()
|
||||
|
||||
|
||||
class _LazyModule(ModuleType):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
"""
|
||||
|
||||
# Very heavily inspired by optuna.integration._IntegrationModule
|
||||
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
||||
def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
|
||||
super().__init__(name)
|
||||
self._modules = set(import_structure.keys())
|
||||
self._class_to_module = {}
|
||||
for key, values in import_structure.items():
|
||||
for value in values:
|
||||
self._class_to_module[value] = key
|
||||
# Needed for autocompletion in an IDE
|
||||
self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
|
||||
self.__file__ = module_file
|
||||
self.__spec__ = module_spec
|
||||
self.__path__ = [os.path.dirname(module_file)]
|
||||
self._objects = {} if extra_objects is None else extra_objects
|
||||
self._name = name
|
||||
self._import_structure = import_structure
|
||||
|
||||
# Needed for autocompletion in an IDE
|
||||
def __dir__(self):
|
||||
result = super().__dir__()
|
||||
# The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
|
||||
# they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
|
||||
for attr in self.__all__:
|
||||
if attr not in result:
|
||||
result.append(attr)
|
||||
return result
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in self._objects:
|
||||
return self._objects[name]
|
||||
if name in self._modules:
|
||||
value = self._get_module(name)
|
||||
elif name in self._class_to_module.keys():
|
||||
module = self._get_module(self._class_to_module[name])
|
||||
value = getattr(module, name)
|
||||
else:
|
||||
raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
||||
|
||||
setattr(self, name, value)
|
||||
return value
|
||||
|
||||
def _get_module(self, module_name: str):
|
||||
try:
|
||||
return importlib.import_module("." + module_name, self.__name__)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
|
||||
f" traceback):\n{e}"
|
||||
) from e
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self._name, self.__file__, self._import_structure))
|
||||
|
||||
|
||||
class OptionalDependencyNotAvailable(BaseException):
|
||||
"""Internally used error class for signalling an optional dependency was not found."""
|
||||
|
@ -13,23 +13,52 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .modeling_base import PreTrainedModelWrapper, create_reference_model
|
||||
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
|
||||
from .utils import setup_chat_format
|
||||
# flake8: noqa
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from ..import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
|
||||
|
||||
|
||||
SUPPORTED_ARCHITECTURES = (
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
)
|
||||
_import_structure = {
|
||||
"modeling_base": ["PreTrainedModelWrapper", "create_reference_model"],
|
||||
"modeling_value_head": [
|
||||
"AutoModelForCausalLMWithValueHead",
|
||||
"AutoModelForSeq2SeqLMWithValueHead",
|
||||
],
|
||||
"utils": ["setup_chat_format", "SUPPORTED_ARCHITECTURES", "unwrap_model_for_generation"],
|
||||
}
|
||||
|
||||
from ..import_utils import is_diffusers_available
|
||||
try:
|
||||
if not is_diffusers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["modeling_sd_base"] = [
|
||||
"DDPOPipelineOutput",
|
||||
"DDPOSchedulerOutput",
|
||||
"DDPOStableDiffusionPipeline",
|
||||
"DefaultDDPOStableDiffusionPipeline",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .modeling_base import PreTrainedModelWrapper, create_reference_model
|
||||
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
|
||||
from .utils import setup_chat_format, SUPPORTED_ARCHITECTURES
|
||||
|
||||
if is_diffusers_available():
|
||||
from .modeling_sd_base import (
|
||||
DDPOPipelineOutput,
|
||||
DDPOSchedulerOutput,
|
||||
DDPOStableDiffusionPipeline,
|
||||
DefaultDDPOStableDiffusionPipeline,
|
||||
)
|
||||
try:
|
||||
if not is_diffusers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .modeling_sd_base import (
|
||||
DDPOPipelineOutput,
|
||||
DDPOSchedulerOutput,
|
||||
DDPOStableDiffusionPipeline,
|
||||
DefaultDDPOStableDiffusionPipeline,
|
||||
)
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
@ -15,6 +15,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
|
||||
|
||||
from ..import_utils import is_npu_available, is_xpu_available
|
||||
from .modeling_base import PreTrainedModelWrapper
|
||||
|
||||
|
||||
@ -245,7 +246,13 @@ class AutoModelForCausalLMWithValueHead(PreTrainedModelWrapper):
|
||||
)
|
||||
|
||||
first_device = list(set(self.pretrained_model.hf_device_map.values()))[0]
|
||||
|
||||
if isinstance(first_device, int):
|
||||
if is_npu_available():
|
||||
first_device = f"npu:{first_device}"
|
||||
elif is_xpu_available():
|
||||
first_device = f"xpu:{first_device}"
|
||||
else:
|
||||
first_device = f"cuda:{first_device}"
|
||||
self.v_head = self.v_head.to(first_device)
|
||||
|
||||
def set_device_hook(module, input, outputs):
|
||||
|
@ -1,8 +1,29 @@
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union
|
||||
|
||||
from accelerate.utils import is_deepspeed_available
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
|
||||
|
||||
|
||||
SUPPORTED_ARCHITECTURES = (
|
||||
AutoModelForCausalLMWithValueHead,
|
||||
AutoModelForSeq2SeqLMWithValueHead,
|
||||
)
|
||||
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from accelerate import Accelerator
|
||||
from deepspeed.runtime.engine import DeepSpeedEngine
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
|
||||
from .modeling_base import PreTrainedModelWrapper
|
||||
|
||||
|
||||
# TODO: Add Abstract Base Class if more formats are added
|
||||
@dataclass
|
||||
@ -76,10 +97,60 @@ def setup_chat_format(
|
||||
model.resize_token_embeddings(
|
||||
len(tokenizer), pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None
|
||||
)
|
||||
# Make sure to update the generation config to use the new eos & bos token
|
||||
# Update the model config to use the new eos & bos tokens
|
||||
if getattr(model, "config", None) is not None:
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
model.config.bos_token_id = tokenizer.bos_token_id
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
# Update the generation config to use the new eos & bos token
|
||||
if getattr(model, "generation_config", None) is not None:
|
||||
model.generation_config.bos_token_id = tokenizer.bos_token_id
|
||||
model.generation_config.eos_token_id = tokenizer.eos_token_id
|
||||
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def remove_hooks(model: "DeepSpeedEngine") -> None:
|
||||
"""Removes the optimizer hooks from a DeepSpeed ZeRO-3 model."""
|
||||
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
|
||||
optimizer_offload = model.optimizer.parameter_offload
|
||||
elif model.optimizer is not None:
|
||||
optimizer_offload = model.optimizer
|
||||
|
||||
for hook in optimizer_offload.forward_hooks:
|
||||
hook.remove()
|
||||
for hook in optimizer_offload.backward_hooks:
|
||||
hook.remove()
|
||||
|
||||
optimizer_offload.forward_hooks = []
|
||||
optimizer_offload.backward_hooks = []
|
||||
|
||||
|
||||
def add_hooks(model: "DeepSpeedEngine") -> None:
|
||||
"""Adds the optimizer hooks from a DeepSpeed ZeRO-3 model."""
|
||||
if model.optimizer is not None and hasattr(model.optimizer, "parameter_offload"):
|
||||
optimizer_offload = model.optimizer.parameter_offload
|
||||
elif model.optimizer is not None:
|
||||
optimizer_offload = model.optimizer
|
||||
optimizer_offload._register_hooks_recursively(optimizer_offload.module)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def unwrap_model_for_generation(
|
||||
model: Union["DistributedDataParallel", "DeepSpeedEngine"], accelerator: "Accelerator", is_peft_model: bool = False
|
||||
) -> Union["PreTrainedModelWrapper", "DeepSpeedEngine"]:
|
||||
"""Context manager to unwrap a model for generation.
|
||||
|
||||
For ZeRO-3 models, we gather the weights once to speed up generation.
|
||||
"""
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
if is_peft_model:
|
||||
unwrapped_model.pretrained_model.disable_adapter()
|
||||
if accelerator.state.deepspeed_plugin is not None and accelerator.state.deepspeed_plugin.zero_stage == 3:
|
||||
with deepspeed.zero.GatheredParameters(model.parameters()):
|
||||
remove_hooks(model)
|
||||
yield model
|
||||
add_hooks(model)
|
||||
else:
|
||||
yield unwrapped_model
|
||||
|
@ -15,32 +15,92 @@
|
||||
# limitations under the License.
|
||||
|
||||
# There is a circular import in the PPOTrainer if we let isort sort these
|
||||
# isort: off
|
||||
from .utils import (
|
||||
AdaptiveKLController,
|
||||
FixedKLController,
|
||||
ConstantLengthDataset,
|
||||
DataCollatorForCompletionOnlyLM,
|
||||
RunningMoments,
|
||||
disable_dropout_in_model,
|
||||
peft_module_casting_to_bf16,
|
||||
)
|
||||
from typing import TYPE_CHECKING
|
||||
from ..import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable
|
||||
|
||||
# isort: on
|
||||
_import_structure = {
|
||||
"utils": [
|
||||
"AdaptiveKLController",
|
||||
"FixedKLController",
|
||||
"ConstantLengthDataset",
|
||||
"DataCollatorForCompletionOnlyLM",
|
||||
"RunningMoments",
|
||||
"disable_dropout_in_model",
|
||||
"peft_module_casting_to_bf16",
|
||||
"RichProgressCallback",
|
||||
],
|
||||
"dpo_trainer": [
|
||||
"DPOTrainer",
|
||||
],
|
||||
"cpo_config": ["CPOConfig"],
|
||||
"cpo_trainer": ["CPOTrainer"],
|
||||
"iterative_sft_trainer": [
|
||||
"IterativeSFTTrainer",
|
||||
],
|
||||
"kto_config": ["KTOConfig"],
|
||||
"kto_trainer": ["KTOTrainer"],
|
||||
"model_config": ["ModelConfig"],
|
||||
"orpo_config": ["ORPOConfig"],
|
||||
"orpo_trainer": ["ORPOTrainer"],
|
||||
"ppo_config": ["PPOConfig"],
|
||||
"ppo_trainer": ["PPOTrainer"],
|
||||
"reward_config": ["RewardConfig"],
|
||||
"reward_trainer": ["RewardTrainer", "compute_accuracy"],
|
||||
"sft_trainer": ["SFTTrainer"],
|
||||
"base": ["BaseTrainer"],
|
||||
"ddpo_config": ["DDPOConfig"],
|
||||
}
|
||||
|
||||
from ..import_utils import is_diffusers_available
|
||||
from .base import BaseTrainer
|
||||
from .ddpo_config import DDPOConfig
|
||||
try:
|
||||
if not is_diffusers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
_import_structure["ddpo_trainer"] = ["DDPOTrainer"]
|
||||
|
||||
|
||||
if is_diffusers_available():
|
||||
from .ddpo_trainer import DDPOTrainer
|
||||
if TYPE_CHECKING:
|
||||
# isort: off
|
||||
from .utils import (
|
||||
AdaptiveKLController,
|
||||
FixedKLController,
|
||||
ConstantLengthDataset,
|
||||
DataCollatorForCompletionOnlyLM,
|
||||
RunningMoments,
|
||||
disable_dropout_in_model,
|
||||
peft_module_casting_to_bf16,
|
||||
RichProgressCallback,
|
||||
)
|
||||
|
||||
from .dpo_trainer import DPOTrainer
|
||||
from .iterative_sft_trainer import IterativeSFTTrainer
|
||||
from .model_config import ModelConfig
|
||||
from .ppo_config import PPOConfig
|
||||
from .ppo_trainer import PPOTrainer
|
||||
from .reward_config import RewardConfig
|
||||
from .reward_trainer import RewardTrainer, compute_accuracy
|
||||
from .sft_trainer import SFTTrainer
|
||||
# isort: on
|
||||
|
||||
from .base import BaseTrainer
|
||||
from .ddpo_config import DDPOConfig
|
||||
|
||||
from .dpo_trainer import DPOTrainer
|
||||
from .iterative_sft_trainer import IterativeSFTTrainer
|
||||
from .cpo_config import CPOConfig
|
||||
from .cpo_trainer import CPOTrainer
|
||||
from .kto_config import KTOConfig
|
||||
from .kto_trainer import KTOTrainer
|
||||
from .model_config import ModelConfig
|
||||
from .orpo_config import ORPOConfig
|
||||
from .orpo_trainer import ORPOTrainer
|
||||
from .ppo_config import PPOConfig
|
||||
from .ppo_trainer import PPOTrainer
|
||||
from .reward_config import RewardConfig
|
||||
from .reward_trainer import RewardTrainer, compute_accuracy
|
||||
from .sft_trainer import SFTTrainer
|
||||
|
||||
try:
|
||||
if not is_diffusers_available():
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
pass
|
||||
else:
|
||||
from .ddpo_trainer import DDPOTrainer
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
|
||||
|
78
trl/trainer/cpo_config.py
Normal file
78
trl/trainer/cpo_config.py
Normal file
@ -0,0 +1,78 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Literal, Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPOConfig(TrainingArguments):
|
||||
r"""
|
||||
CPOConfig collects all training arguments related to the [`CPOTrainer`] class.
|
||||
|
||||
Using [`HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
max_length (`int`, defaults to `None`):
|
||||
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
|
||||
max_prompt_length (`int`, defaults to `None`):
|
||||
The maximum length of the prompt. This argument is required if you want to use the default data collator.
|
||||
max_target_length (`int`, defaults to `None`):
|
||||
The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
|
||||
beta (`float`, defaults to 0.1):
|
||||
The beta factor in CPO loss.
|
||||
label_smoothing (`float`, defaults to 0):
|
||||
The label smoothing factor. This argument is required if you want to use the default data collator.
|
||||
loss_type (`str`, defaults to `sigmoid`):
|
||||
The type of loss to use. This argument is required if you want to use the default data collator.
|
||||
label_pad_token_id (`int`, defaults to `-100`):
|
||||
The label pad token id. This argument is required if you want to use the default data collator.
|
||||
padding_value (`int`, defaults to `None`):
|
||||
The padding value if it is different to the tokenizer's pad_token_id.
|
||||
truncation_mode (`str`, defaults to `keep_end`):
|
||||
The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
|
||||
generate_during_eval (`bool`, defaults to `False`):
|
||||
Whether to sample and log generations during evaluation step.
|
||||
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
|
||||
If no model is provided, we need to know if the model_init returns an encoder-decoder.
|
||||
disable_dropout (`bool`, defaults to `True`):
|
||||
Whether or not to disable dropouts in `model`.
|
||||
model_init_kwargs (`Optional[Dict]`, *optional*):
|
||||
Dict of Optional kwargs to pass when instantiating the model from a string
|
||||
dataset_num_proc (`Optional[int]`, *optional*):
|
||||
The number of workers to use to tokenize the data. Defaults to None.
|
||||
"""
|
||||
|
||||
max_length: Optional[int] = None
|
||||
max_prompt_length: Optional[int] = None
|
||||
max_completion_length: Optional[int] = None
|
||||
max_target_length: Optional[int] = None
|
||||
|
||||
beta: float = 0.1
|
||||
label_smoothing: float = 0
|
||||
loss_type: Literal["sigmoid", "hinge", "ipo", "kto_pair"] = "sigmoid"
|
||||
disable_dropout: bool = True
|
||||
|
||||
label_pad_token_id: int = -100
|
||||
padding_value: int = None
|
||||
truncation_mode: str = "keep_end"
|
||||
generate_during_eval: bool = False
|
||||
is_encoder_decoder: Optional[bool] = None
|
||||
|
||||
model_init_kwargs: Optional[Dict] = None
|
||||
|
||||
dataset_num_proc: Optional[int] = None
|
929
trl/trainer/cpo_trainer.py
Normal file
929
trl/trainer/cpo_trainer.py
Normal file
@ -0,0 +1,929 @@
|
||||
# CPO Authors: Haoran Xu, Amr Sharaf, Yunmo Chen, Weiting Tan, Lingfeng Shen, Benjamin Van Durme, Kenton Murray, Young Jin Kim
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from accelerate import PartialState
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import EvalLoopOutput
|
||||
from transformers.utils import is_torch_fx_proxy
|
||||
|
||||
from ..import_utils import is_peft_available, is_wandb_available
|
||||
from .cpo_config import CPOConfig
|
||||
from .utils import (
|
||||
DPODataCollatorWithPadding,
|
||||
disable_dropout_in_model,
|
||||
pad_to_length,
|
||||
peft_module_casting_to_bf16,
|
||||
trl_sanitze_kwargs_for_tagging,
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
class CPOTrainer(Trainer):
|
||||
r"""
|
||||
Initialize CPOTrainer.
|
||||
|
||||
Args:
|
||||
model (`transformers.PreTrainedModel`):
|
||||
The model to train, preferably an `AutoModelForSequenceClassification`.
|
||||
args (`CPOConfig`):
|
||||
The CPO config arguments to use for training.
|
||||
data_collator (`transformers.DataCollator`):
|
||||
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
||||
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
||||
train_dataset (`datasets.Dataset`):
|
||||
The dataset to use for training.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
The dataset to use for evaluation.
|
||||
tokenizer (`transformers.PreTrainedTokenizerBase`):
|
||||
The tokenizer to use for training. This argument is required if you want to use the default data collator.
|
||||
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
||||
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
||||
callbacks (`List[transformers.TrainerCallback]`):
|
||||
The callbacks to use for training.
|
||||
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
||||
The optimizer and scheduler to use for training.
|
||||
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
||||
The function to use to preprocess the logits before computing the metrics.
|
||||
peft_config (`Dict`, defaults to `None`):
|
||||
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
||||
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
|
||||
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
||||
a dictionary string to metric values.
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "cpo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
args: Optional[CPOConfig] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[Dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
|
||||
):
|
||||
if args.model_init_kwargs is None:
|
||||
model_init_kwargs = {}
|
||||
elif not isinstance(model, str):
|
||||
raise ValueError("You passed model_kwargs to the CPOTrainer. But your model is already instantiated.")
|
||||
else:
|
||||
model_init_kwargs = args.model_init_kwargs
|
||||
model_init_kwargs["torch_dtype"] = (
|
||||
model_init_kwargs["torch_dtype"]
|
||||
if model_init_kwargs["torch_dtype"] in ["auto", None]
|
||||
else getattr(torch, model_init_kwargs["torch_dtype"])
|
||||
)
|
||||
|
||||
if isinstance(model, str):
|
||||
warnings.warn(
|
||||
"You passed a model_id to the CPOTrainer. This will automatically create an "
|
||||
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
||||
|
||||
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
||||
# has been called in order to properly call autocast if needed.
|
||||
self._peft_has_been_casted_to_bf16 = False
|
||||
|
||||
if not is_peft_available() and peft_config is not None:
|
||||
raise ValueError(
|
||||
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
||||
)
|
||||
elif is_peft_available() and peft_config is not None:
|
||||
# if model is a peft model and we have a peft_config, we merge and unload it first
|
||||
if isinstance(model, PeftModel):
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
||||
_support_gc_kwargs = hasattr(
|
||||
args, "gradient_checkpointing_kwargs"
|
||||
) and "gradient_checkpointing_kwargs" in list(
|
||||
inspect.signature(prepare_model_for_kbit_training).parameters
|
||||
)
|
||||
|
||||
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||
|
||||
if _support_gc_kwargs:
|
||||
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||
|
||||
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
||||
elif getattr(args, "gradient_checkpointing", False):
|
||||
# For backward compatibility with older versions of transformers
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
# get peft model with the given config
|
||||
model = get_peft_model(model, peft_config)
|
||||
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
||||
peft_module_casting_to_bf16(model)
|
||||
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
||||
self._peft_has_been_casted_to_bf16 = True
|
||||
|
||||
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
||||
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
||||
# fail or completely fail.
|
||||
elif getattr(args, "gradient_checkpointing", False):
|
||||
# For backward compatibility with older versions of transformers
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
if args.generate_during_eval and not is_wandb_available():
|
||||
raise ValueError(
|
||||
"`generate_during_eval=True` requires Weights and Biases to be installed."
|
||||
" Please install `wandb` to resolve."
|
||||
)
|
||||
|
||||
if model is not None:
|
||||
self.is_encoder_decoder = model.config.is_encoder_decoder
|
||||
elif args.is_encoder_decoder is None:
|
||||
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
||||
else:
|
||||
self.is_encoder_decoder = args.is_encoder_decoder
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.decoder_start_token_id = model.config.decoder_start_token_id
|
||||
self.pad_token_id = model.config.pad_token_id
|
||||
|
||||
if tokenizer is None:
|
||||
raise ValueError("tokenizer must be specified to tokenize a CPO dataset.")
|
||||
if args.max_length is None:
|
||||
warnings.warn(
|
||||
"`max_length` is not set in the CPOConfig's init"
|
||||
" it will default to `512` by default, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
max_length = 512
|
||||
else:
|
||||
max_length = args.max_length
|
||||
if args.max_prompt_length is None:
|
||||
warnings.warn(
|
||||
"`max_prompt_length` is not set in the CPOConfig's init"
|
||||
" it will default to `128` by default, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
max_prompt_length = 128
|
||||
else:
|
||||
max_prompt_length = args.max_prompt_length
|
||||
|
||||
if args.max_target_length is None and self.is_encoder_decoder:
|
||||
warnings.warn(
|
||||
"When using an encoder decoder architecture, you should set `max_target_length` in the CPOConfig's init"
|
||||
" it will default to `128` by default, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
max_target_length = 128
|
||||
else:
|
||||
max_target_length = args.max_target_length
|
||||
|
||||
if data_collator is None:
|
||||
data_collator = DPODataCollatorWithPadding(
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
label_pad_token_id=args.label_pad_token_id,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
)
|
||||
|
||||
if args.remove_unused_columns:
|
||||
args.remove_unused_columns = False
|
||||
# warn users
|
||||
warnings.warn(
|
||||
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
||||
" we have set it for you, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
self.use_dpo_data_collator = True
|
||||
else:
|
||||
self.use_dpo_data_collator = False
|
||||
|
||||
if args.disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
|
||||
self.max_length = max_length
|
||||
self.generate_during_eval = args.generate_during_eval
|
||||
self.label_pad_token_id = args.label_pad_token_id
|
||||
self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
|
||||
self.max_prompt_length = max_prompt_length
|
||||
self.truncation_mode = args.truncation_mode
|
||||
self.max_target_length = max_target_length
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
if args.loss_type in ["hinge", "ipo", "kto_pair"] and args.label_smoothing > 0:
|
||||
warnings.warn(
|
||||
"You are using a loss type that does not support label smoothing. Ignoring label_smoothing parameter."
|
||||
)
|
||||
|
||||
self.beta = args.beta
|
||||
self.label_smoothing = args.label_smoothing
|
||||
self.loss_type = args.loss_type
|
||||
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# Compute that only on the main process for faster data processing.
|
||||
# see: https://github.com/huggingface/trl/pull/1255
|
||||
with PartialState().local_main_process_first():
|
||||
# tokenize the dataset
|
||||
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
||||
if eval_dataset is not None:
|
||||
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
model_init=model_init,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
if hasattr(self.model, "add_model_tags"):
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError(
|
||||
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
||||
)
|
||||
|
||||
def build_tokenized_answer(self, prompt, answer):
|
||||
"""
|
||||
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
|
||||
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
|
||||
Reference:
|
||||
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
||||
"""
|
||||
|
||||
full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
|
||||
prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
|
||||
|
||||
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
||||
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
||||
|
||||
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
||||
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
||||
|
||||
# Prepare input tokens for token by token comparison
|
||||
full_input_ids = np.array(full_tokenized["input_ids"])
|
||||
|
||||
if len(full_input_ids) != len(full_concat_input_ids):
|
||||
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
||||
|
||||
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
||||
# can be merged together when tokenizing prompt+answer. This could result
|
||||
# on the last token from the prompt being different when tokenized on its own
|
||||
# vs when done as prompt+answer.
|
||||
response_token_ids_start_idx = len(prompt_input_ids)
|
||||
|
||||
# If tokenized prompt is different than both prompt+answer, then it means the
|
||||
# last token has changed due to merging.
|
||||
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
||||
response_token_ids_start_idx -= 1
|
||||
|
||||
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
||||
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
||||
|
||||
if len(prompt_input_ids) != len(prompt_attention_mask):
|
||||
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
||||
|
||||
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
||||
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
||||
|
||||
return dict(
|
||||
prompt_input_ids=prompt_input_ids,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
input_ids=answer_input_ids,
|
||||
attention_mask=answer_attention_mask,
|
||||
)
|
||||
|
||||
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
|
||||
"""Tokenize a single row from a CPO specific dataset.
|
||||
|
||||
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
|
||||
in case the prompt + chosen or prompt + rejected responses is/are too long. First
|
||||
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
|
||||
|
||||
We also create the labels for the chosen/rejected responses, which are of length equal to
|
||||
the sum of the length of the prompt and the chosen/rejected response, with
|
||||
label_pad_token_id for the prompt tokens.
|
||||
"""
|
||||
batch = {}
|
||||
prompt = feature["prompt"]
|
||||
chosen = feature["chosen"]
|
||||
rejected = feature["rejected"]
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
# Check issues below for more details
|
||||
# 1. https://github.com/huggingface/trl/issues/907
|
||||
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
||||
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
||||
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
||||
prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
|
||||
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
||||
|
||||
if not isinstance(chosen, str):
|
||||
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
||||
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
||||
|
||||
if not isinstance(rejected, str):
|
||||
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
||||
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
||||
|
||||
# Last prompt token might get merged by tokenizer and
|
||||
# it should not be included for generation if that happens
|
||||
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
||||
|
||||
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
||||
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
||||
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
||||
|
||||
for k, v in prompt_tokens.items():
|
||||
prompt_tokens[k] = v[:prompt_len_input_ids]
|
||||
|
||||
# Make sure prompts only have one different token at most an
|
||||
# and length only differs by 1 at most
|
||||
num_diff_tokens = sum(
|
||||
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
|
||||
)
|
||||
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
||||
if num_diff_tokens > 1 or num_diff_len > 1:
|
||||
raise ValueError(
|
||||
"Chosen and rejected prompt_input_ids might only differ on the "
|
||||
"last token due to tokenizer merge ops."
|
||||
)
|
||||
|
||||
# add BOS token to head of prompt
|
||||
prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
|
||||
chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
|
||||
rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
|
||||
|
||||
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
|
||||
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
|
||||
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
|
||||
|
||||
# add EOS token to end of answer
|
||||
chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
chosen_tokens["attention_mask"].append(1)
|
||||
|
||||
rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
rejected_tokens["attention_mask"].append(1)
|
||||
|
||||
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
||||
|
||||
# if combined sequence is too long, truncate the prompt
|
||||
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
|
||||
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
||||
if self.truncation_mode == "keep_start":
|
||||
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
||||
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
|
||||
elif self.truncation_mode == "keep_end":
|
||||
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
||||
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
|
||||
else:
|
||||
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
||||
|
||||
# if that's still too long, truncate the response
|
||||
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
||||
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
||||
for k in ["input_ids", "attention_mask"]:
|
||||
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
|
||||
|
||||
# Create labels
|
||||
chosen_sequence_tokens = {
|
||||
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
||||
}
|
||||
rejected_sequence_tokens = {
|
||||
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
||||
}
|
||||
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
||||
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
|
||||
self.label_pad_token_id
|
||||
] * len(chosen_tokens["prompt_input_ids"])
|
||||
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
||||
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
|
||||
self.label_pad_token_id
|
||||
] * len(rejected_tokens["prompt_input_ids"])
|
||||
|
||||
for k, toks in {
|
||||
"chosen_": chosen_sequence_tokens,
|
||||
"rejected_": rejected_sequence_tokens,
|
||||
"": prompt_tokens,
|
||||
}.items():
|
||||
for type_key, tokens in toks.items():
|
||||
if type_key == "token_type_ids":
|
||||
continue
|
||||
batch[f"{k}{type_key}"] = tokens
|
||||
|
||||
else:
|
||||
chosen_tokens = self.tokenizer(
|
||||
chosen, truncation=True, max_length=self.max_target_length, add_special_tokens=True
|
||||
)
|
||||
rejected_tokens = self.tokenizer(
|
||||
rejected, truncation=True, max_length=self.max_target_length, add_special_tokens=True
|
||||
)
|
||||
prompt_tokens = self.tokenizer(
|
||||
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
|
||||
)
|
||||
|
||||
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
||||
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
||||
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
||||
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
||||
|
||||
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
||||
labels=torch.tensor(batch["rejected_labels"])
|
||||
)
|
||||
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
||||
labels=torch.tensor(batch["chosen_labels"])
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
@staticmethod
|
||||
def concatenated_inputs(
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
is_encoder_decoder: bool = False,
|
||||
label_pad_token_id: int = -100,
|
||||
padding_value: int = 0,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Dict[str, torch.LongTensor]:
|
||||
"""Concatenate the chosen and rejected inputs into a single tensor.
|
||||
|
||||
Args:
|
||||
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
|
||||
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
||||
label_pad_token_id: The label pad token id.
|
||||
padding_value: The padding value to use for the concatenated inputs_ids.
|
||||
device: The device for the concatenated inputs.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
||||
"""
|
||||
concatenated_batch = {}
|
||||
|
||||
if is_encoder_decoder:
|
||||
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
||||
else:
|
||||
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
||||
|
||||
for k in batch:
|
||||
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
||||
if "labels" in k or is_encoder_decoder:
|
||||
pad_value = label_pad_token_id
|
||||
elif k.endswith("_input_ids"):
|
||||
pad_value = padding_value
|
||||
elif k.endswith("_attention_mask"):
|
||||
pad_value = 0
|
||||
concatenated_key = k.replace("chosen", "concatenated")
|
||||
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
||||
for k in batch:
|
||||
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
||||
if "labels" in k or is_encoder_decoder:
|
||||
pad_value = label_pad_token_id
|
||||
elif k.endswith("_input_ids"):
|
||||
pad_value = padding_value
|
||||
elif k.endswith("_attention_mask"):
|
||||
pad_value = 0
|
||||
concatenated_key = k.replace("rejected", "concatenated")
|
||||
concatenated_batch[concatenated_key] = torch.cat(
|
||||
(
|
||||
concatenated_batch[concatenated_key],
|
||||
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
||||
),
|
||||
dim=0,
|
||||
).to(device=device)
|
||||
|
||||
if is_encoder_decoder:
|
||||
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
||||
concatenated_batch["concatenated_attention_mask"] = (
|
||||
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
||||
)
|
||||
|
||||
return concatenated_batch
|
||||
|
||||
def cpo_loss(
|
||||
self,
|
||||
policy_chosen_logps: torch.FloatTensor,
|
||||
policy_rejected_logps: torch.FloatTensor,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Compute the CPO loss for a batch of policy and reference model log probabilities.
|
||||
|
||||
Args:
|
||||
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
||||
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
||||
|
||||
Returns:
|
||||
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
||||
The losses tensor contains the CPO loss for each example in the batch.
|
||||
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
||||
"""
|
||||
logits = (policy_chosen_logps - policy_rejected_logps).to(self.accelerator.device)
|
||||
|
||||
# The beta is a temperature parameter for the CPO loss, typically something in the range of 0.1 to 0.5.
|
||||
# We ignore the reference model as beta -> 0. The label_smoothing parameter encodes our uncertainty about the labels and
|
||||
# calculates a conservative CPO loss.
|
||||
if self.loss_type == "sigmoid":
|
||||
# This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
|
||||
losses = (
|
||||
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
||||
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
||||
)
|
||||
elif self.loss_type == "hinge":
|
||||
losses = torch.relu(1 - self.beta * logits)
|
||||
elif self.loss_type == "ipo":
|
||||
# eqn (17) of the paper where beta is the regularization parameter for the IPO loss, denoted by tau in the paper.
|
||||
losses = (logits - 1 / (2 * self.beta)) ** 2
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid', 'hinge', 'ipo', 'kto_pair']"
|
||||
)
|
||||
|
||||
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
||||
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards
|
||||
|
||||
@staticmethod
|
||||
def get_batch_logps(
|
||||
logits: torch.FloatTensor,
|
||||
labels: torch.LongTensor,
|
||||
average_log_prob: bool = False,
|
||||
label_pad_token_id: int = -100,
|
||||
is_encoder_decoder: bool = False,
|
||||
) -> torch.FloatTensor:
|
||||
"""Compute the log probabilities of the given labels under the given logits.
|
||||
|
||||
Args:
|
||||
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
||||
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
||||
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
||||
label_pad_token_id: The label pad token id.
|
||||
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
||||
"""
|
||||
if logits.shape[:-1] != labels.shape:
|
||||
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
||||
|
||||
if not is_encoder_decoder:
|
||||
labels = labels[:, 1:].clone()
|
||||
logits = logits[:, :-1, :]
|
||||
loss_mask = labels != label_pad_token_id
|
||||
|
||||
# dummy token; we'll ignore the losses on these tokens later
|
||||
labels[labels == label_pad_token_id] = 0
|
||||
|
||||
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
|
||||
|
||||
if average_log_prob:
|
||||
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
||||
else:
|
||||
return (per_token_logps * loss_mask).sum(-1)
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
||||
|
||||
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
||||
"""
|
||||
concatenated_batch = self.concatenated_inputs(
|
||||
batch,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
padding_value=self.padding_value,
|
||||
device=self.accelerator.device,
|
||||
)
|
||||
len_chosen = batch["chosen_labels"].shape[0]
|
||||
|
||||
model_kwargs = (
|
||||
{
|
||||
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
||||
}
|
||||
if self.is_encoder_decoder
|
||||
else {}
|
||||
)
|
||||
|
||||
outputs = model(
|
||||
concatenated_batch["concatenated_input_ids"],
|
||||
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
||||
use_cache=False,
|
||||
**model_kwargs,
|
||||
)
|
||||
all_logits = outputs.logits
|
||||
|
||||
def cross_entropy_loss(logits, labels):
|
||||
if not self.is_encoder_decoder:
|
||||
# Shift so that tokens < n predict n
|
||||
logits = logits[..., :-1, :].contiguous()
|
||||
labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
logits = logits.view(-1, logits.shape[-1])
|
||||
labels = labels.view(-1)
|
||||
# Enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
loss = loss_fct(logits, labels)
|
||||
return loss
|
||||
|
||||
labels = concatenated_batch["concatenated_labels"].clone()
|
||||
nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
||||
|
||||
all_logps = self.get_batch_logps(
|
||||
all_logits,
|
||||
concatenated_batch["concatenated_labels"],
|
||||
average_log_prob=self.loss_type == "ipo",
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
|
||||
chosen_logps = all_logps[:len_chosen]
|
||||
rejected_logps = all_logps[len_chosen:]
|
||||
|
||||
chosen_logits = all_logits[:len_chosen]
|
||||
rejected_logits = all_logits[len_chosen:]
|
||||
|
||||
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, nll_loss)
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
policy_nll_loss,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards = self.cpo_loss(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
)
|
||||
|
||||
loss = losses.mean() + policy_nll_loss
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
|
||||
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
|
||||
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
|
||||
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
|
||||
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
|
||||
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
|
||||
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
|
||||
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
|
||||
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
|
||||
|
||||
return loss, metrics
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
return_outputs=False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
|
||||
if not self.use_dpo_data_collator:
|
||||
warnings.warn(
|
||||
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
|
||||
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
|
||||
)
|
||||
|
||||
compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
|
||||
|
||||
with compute_loss_context_manager():
|
||||
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
||||
|
||||
# force log the metrics
|
||||
self.store_metrics(metrics, train_eval="train")
|
||||
|
||||
if return_outputs:
|
||||
return (loss, metrics)
|
||||
return loss
|
||||
|
||||
def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
|
||||
"""Generate samples from the model and reference model for the given batch of inputs."""
|
||||
|
||||
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
||||
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
||||
generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
|
||||
|
||||
with generate_context_manager():
|
||||
policy_output = model.generate(
|
||||
input_ids=batch["prompt_input_ids"],
|
||||
attention_mask=batch["prompt_attention_mask"],
|
||||
max_length=self.max_length,
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
|
||||
policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
|
||||
policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
|
||||
|
||||
return policy_output_decoded
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
):
|
||||
if not self.use_dpo_data_collator:
|
||||
warnings.warn(
|
||||
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
|
||||
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
|
||||
)
|
||||
if ignore_keys is None:
|
||||
if hasattr(model, "config"):
|
||||
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
||||
else:
|
||||
ignore_keys = []
|
||||
|
||||
prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
|
||||
|
||||
with torch.no_grad(), prediction_context_manager():
|
||||
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
|
||||
|
||||
# force log the metrics
|
||||
self.store_metrics(metrics, train_eval="eval")
|
||||
|
||||
if prediction_loss_only:
|
||||
return (loss.detach(), None, None)
|
||||
|
||||
# logits for the chosen and rejected samples from model
|
||||
logits_dict = {
|
||||
"eval_logits/chosen": metrics["eval_logits/chosen"],
|
||||
"eval_logits/rejected": metrics["eval_logits/rejected"],
|
||||
}
|
||||
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
|
||||
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
|
||||
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
||||
|
||||
return (loss.detach(), logits, labels)
|
||||
|
||||
def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def evaluation_loop(
|
||||
self,
|
||||
dataloader: DataLoader,
|
||||
description: str,
|
||||
prediction_loss_only: Optional[bool] = None,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
) -> EvalLoopOutput:
|
||||
"""
|
||||
Overriding built-in evaluation loop to store metrics for each batch.
|
||||
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
||||
|
||||
Works both with or without labels.
|
||||
"""
|
||||
|
||||
# Sample and save to game log if requested (for one batch to save time)
|
||||
if self.generate_during_eval:
|
||||
# Generate random indices within the range of the total number of samples
|
||||
num_samples = len(dataloader.dataset)
|
||||
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
||||
|
||||
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
||||
random_batch_dataset = dataloader.dataset.select(random_indices)
|
||||
random_batch = self.data_collator(random_batch_dataset)
|
||||
random_batch = self._prepare_inputs(random_batch)
|
||||
|
||||
policy_output_decoded = self.get_batch_samples(self.model, random_batch)
|
||||
|
||||
self.log(
|
||||
{
|
||||
"game_log": wandb.Table(
|
||||
columns=["Prompt", "Policy"],
|
||||
rows=[
|
||||
[prompt, pol[len(prompt) :]]
|
||||
for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
|
||||
],
|
||||
)
|
||||
}
|
||||
)
|
||||
self.state.log_history.pop()
|
||||
|
||||
# Base evaluation
|
||||
initial_output = super().evaluation_loop(
|
||||
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
||||
)
|
||||
|
||||
return initial_output
|
||||
|
||||
def log(self, logs: Dict[str, float]) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
Args:
|
||||
logs (`Dict[str, float]`):
|
||||
The values to log.
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
return super().log(logs)
|
||||
|
||||
def _shift_right(self, input_ids):
|
||||
if self.decoder_start_token_id is None:
|
||||
raise ValueError(
|
||||
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
|
||||
)
|
||||
|
||||
# shift inputs to the right
|
||||
if is_torch_fx_proxy(input_ids):
|
||||
# Item assignment is not supported natively for proxies.
|
||||
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
|
||||
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
|
||||
else:
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
||||
|
||||
if self.pad_token_id is None:
|
||||
raise ValueError("model.config.pad_token_id has to be defined.")
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@wraps(Trainer.push_to_hub)
|
||||
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tag "cpo" when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
|
@ -121,7 +121,7 @@ class DPOTrainer(Trainer):
|
||||
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
||||
a dictionary string to metric values.
|
||||
precompute_ref_log_probs (`bool`, defaults to `False`):
|
||||
Flag to precompute reference model log probabilities and evaluation datasets. This is useful if you want to train
|
||||
Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful if you want to train
|
||||
without the reference model and reduce the total GPU memory needed.
|
||||
dataset_num_proc (`Optional[int]`, *optional*):
|
||||
The number of workers to use to tokenize the data. Defaults to None.
|
||||
@ -135,6 +135,8 @@ class DPOTrainer(Trainer):
|
||||
Name of the reference PEFT adapter, when using LoRA with multiple adapters.
|
||||
reference_free (`bool`):
|
||||
If True, we ignore the _provided_ reference model and implicitly use a reference model that assigns equal probability to all responses.
|
||||
force_use_ref_model (`bool`, defaults to `False`):
|
||||
In case one passes a PEFT model for the active model and you want to use a different model for the ref_model, set this flag to `True`.
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "dpo"]
|
||||
@ -173,6 +175,7 @@ class DPOTrainer(Trainer):
|
||||
model_adapter_name: Optional[str] = None,
|
||||
ref_adapter_name: Optional[str] = None,
|
||||
reference_free: bool = False,
|
||||
force_use_ref_model: bool = False,
|
||||
):
|
||||
if model_init_kwargs is None:
|
||||
model_init_kwargs = {}
|
||||
@ -213,10 +216,11 @@ class DPOTrainer(Trainer):
|
||||
if isinstance(model, PeftModel):
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if ref_model is not None:
|
||||
if ref_model is not None and not force_use_ref_model:
|
||||
raise ValueError(
|
||||
"You passed both a ref_model and a peft_config. For training PEFT adapters with DPO there is no need to pass a reference"
|
||||
" model. Please pass `ref_model=None` in case you want to train PEFT adapters."
|
||||
" model. Please pass `ref_model=None` in case you want to train PEFT adapters, or pass a ref_model with `force_use_ref_model=True` in DPOTrainer's init."
|
||||
" if you want to use a different ref_model."
|
||||
)
|
||||
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
||||
@ -226,12 +230,12 @@ class DPOTrainer(Trainer):
|
||||
inspect.signature(prepare_model_for_kbit_training).parameters
|
||||
)
|
||||
|
||||
preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||
|
||||
if _support_gc_kwargs:
|
||||
preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||
|
||||
model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
|
||||
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
||||
elif getattr(args, "gradient_checkpointing", False):
|
||||
# For backward compatibility with older versions of transformers
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
@ -250,7 +254,7 @@ class DPOTrainer(Trainer):
|
||||
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
||||
self._peft_has_been_casted_to_bf16 = True
|
||||
|
||||
# For models that use gradient_checkpoiting, we need to attach a hook that enables input
|
||||
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
||||
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
||||
# fail or completely fail.
|
||||
elif getattr(args, "gradient_checkpointing", False):
|
||||
@ -731,10 +735,10 @@ class DPOTrainer(Trainer):
|
||||
|
||||
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
||||
labels=batch["rejected_labels"]
|
||||
labels=torch.tensor(batch["rejected_labels"])
|
||||
)
|
||||
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
||||
labels=batch["chosen_labels"]
|
||||
labels=torch.tensor(batch["chosen_labels"])
|
||||
)
|
||||
|
||||
return batch
|
||||
@ -1076,6 +1080,8 @@ class DPOTrainer(Trainer):
|
||||
with compute_loss_context_manager():
|
||||
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
||||
|
||||
# Make sure to move the loss to the device the original accumulating loss is at back in the `Trainer` class:
|
||||
loss = loss.to(self.args.device)
|
||||
# force log the metrics
|
||||
self.store_metrics(metrics, train_eval="train")
|
||||
|
||||
@ -1086,7 +1092,7 @@ class DPOTrainer(Trainer):
|
||||
def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
|
||||
"""Generate samples from the model and reference model for the given batch of inputs."""
|
||||
|
||||
# If one uses `generate_during_eval` with peft + bf16, we need to explictly call generate with
|
||||
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
||||
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
||||
generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
|
||||
|
||||
@ -1242,7 +1248,7 @@ class DPOTrainer(Trainer):
|
||||
@wraps(Trainer.push_to_hub)
|
||||
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tag "sft" when pushing the
|
||||
Overwrite the `push_to_hub` method in order to force-add the tag "dpo" when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
|
||||
|
84
trl/trainer/kto_config.py
Normal file
84
trl/trainer/kto_config.py
Normal file
@ -0,0 +1,84 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class KTOConfig(TrainingArguments):
|
||||
r"""
|
||||
KTOConfig collects all training arguments related to the [`KTOTrainer`] class.
|
||||
|
||||
Using [`HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
max_length (`int`, *optional*, defaults to `None`):
|
||||
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
|
||||
max_prompt_length (`int`, *optional*, defaults to `None`):
|
||||
The maximum length of the prompt. This argument is required if you want to use the default data collator.
|
||||
max_completion_length (`int`, *optional*, defaults to `None`):
|
||||
The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
|
||||
beta (`float`, defaults to 0.1):
|
||||
The beta factor in KTO loss. Higher beta means less divergence from the initial policy.
|
||||
desirable_weight (`float`, *optional*, defaults to 1.0):
|
||||
The desirable losses are weighed by this factor to counter unequal number of desirable and undesirable paris.
|
||||
undesirable_weight (`float`, *optional*, defaults to 1.0):
|
||||
The undesirable losses are weighed by this factor to counter unequal number of desirable and undesirable pairs.
|
||||
label_pad_token_id (`int`, defaults to `-100`):
|
||||
The label pad token id. This argument is required if you want to use the default data collator.
|
||||
padding_value (`int`, defaults to `0`):
|
||||
The padding value if it is different to the tokenizer's pad_token_id.
|
||||
truncation_mode (`str`, defaults to `keep_end`):
|
||||
The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
|
||||
generate_during_eval (`bool`, defaults to `False`):
|
||||
Whether to sample and log generations during evaluation step.
|
||||
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
|
||||
If no model is provided, we need to know if the model_init returns an encoder-decoder.
|
||||
precompute_ref_log_probs (`bool`, defaults to `False`):
|
||||
Flag to precompute reference model log probabilities for training and evaluation datasets. This is useful if you want to train
|
||||
without the reference model and reduce the total GPU memory needed.
|
||||
model_init_kwargs: (`Optional[Dict]`, *optional*):
|
||||
Dict of Optional kwargs to pass when instantiating the model from a string.
|
||||
ref_model_init_kwargs: (`Optional[Dict]`, *optional*):
|
||||
Dict of Optional kwargs to pass when instantiating the ref model from a string.
|
||||
dataset_num_proc: (`Optional[int]`, *optional*, defaults to `None`):
|
||||
Number of processes to use for processing the datasets.
|
||||
"""
|
||||
|
||||
max_length: Optional[int] = None
|
||||
"""The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."""
|
||||
max_prompt_length: Optional[int] = None
|
||||
"""The maximum length of the prompt. This argument is required if you want to use the default data collator."""
|
||||
max_completion_length: Optional[int] = None
|
||||
"""The maximum length of the target. This argument is required if you want to use the default data collator and your model is an encoder-decoder."""
|
||||
beta: float = 0.1
|
||||
"""The beta factor in KTO loss. Higher beta means less divergence from the initial policy."""
|
||||
desirable_weight: Optional[float] = 1.0
|
||||
"""The desirable losses are weighed by this factor."""
|
||||
undesirable_weight: Optional[float] = 1.0
|
||||
"""The undesirable losses are weighed by this factor."""
|
||||
|
||||
label_pad_token_id: int = -100
|
||||
padding_value: int = None
|
||||
truncation_mode: str = "keep_end"
|
||||
generate_during_eval: bool = False
|
||||
is_encoder_decoder: Optional[bool] = None
|
||||
precompute_ref_log_probs: bool = False
|
||||
model_init_kwargs: Optional[Dict] = None
|
||||
ref_model_init_kwargs: Optional[Dict] = None
|
||||
dataset_num_proc: Optional[int] = None
|
1314
trl/trainer/kto_trainer.py
Normal file
1314
trl/trainer/kto_trainer.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -61,6 +61,9 @@ class ModelConfig:
|
||||
default=None,
|
||||
metadata={"help": ("Model layers to unfreeze & train")},
|
||||
)
|
||||
lora_task_type: str = field(
|
||||
default="CAUSAL_LM", metadata={"help": "The task_type to pass for LoRA (use SEQ_CLS for reward modeling)"}
|
||||
)
|
||||
load_in_8bit: bool = field(
|
||||
default=False, metadata={"help": "use 8 bit precision for the base model - works only with LoRA"}
|
||||
)
|
||||
@ -82,3 +85,6 @@ class ModelConfig:
|
||||
def __post_init__(self):
|
||||
if self.load_in_8bit and self.load_in_4bit:
|
||||
raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
|
||||
|
||||
if self.lora_target_modules == ["all-linear"]:
|
||||
self.lora_target_modules = "all-linear"
|
||||
|
71
trl/trainer/orpo_config.py
Normal file
71
trl/trainer/orpo_config.py
Normal file
@ -0,0 +1,71 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
from transformers import TrainingArguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class ORPOConfig(TrainingArguments):
|
||||
r"""
|
||||
ORPOConfig collects all training arguments related to the [`ORPOTrainer`] class.
|
||||
|
||||
Using [`HfArgumentParser`] we can turn this class into
|
||||
[argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
|
||||
command line.
|
||||
|
||||
Parameters:
|
||||
max_length (`int`, defaults to `None`):
|
||||
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
|
||||
max_prompt_length (`int`, defaults to `None`):
|
||||
The maximum length of the prompt. This argument is required if you want to use the default data collator.
|
||||
max_completion_length (`int`, defaults to `None`):
|
||||
The maximum length of the completions. This argument is required if you want to use the default data collator and your model is an encoder-decoder.
|
||||
beta (`float`, defaults to 0.1):
|
||||
The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss.
|
||||
label_pad_token_id (`int`, defaults to `-100`):
|
||||
The label pad token id. This argument is required if you want to use the default data collator.
|
||||
padding_value (`int`, defaults to `None`):
|
||||
The padding value if it is different to the tokenizer's pad_token_id.
|
||||
truncation_mode (`str`, defaults to `keep_end`):
|
||||
The truncation mode to use, either `keep_end` or `keep_start`. This argument is required if you want to use the default data collator.
|
||||
generate_during_eval (`bool`, defaults to `False`):
|
||||
Whether to sample and log generations during evaluation step.
|
||||
is_encoder_decoder (`Optional[bool]`, `optional`, defaults to `None`):
|
||||
If no model is provided, we need to know if the model_init returns an encoder-decoder.
|
||||
disable_dropout (`bool`, defaults to `True`):
|
||||
Whether or not to disable dropouts in `model`.
|
||||
model_init_kwargs (`Optional[Dict]`, *optional*):
|
||||
Dict of Optional kwargs to pass when instantiating the model from a string
|
||||
dataset_num_proc (`Optional[int]`, *optional*):
|
||||
The number of workers to use to tokenize the data. Defaults to None.
|
||||
"""
|
||||
|
||||
max_length: Optional[int] = None
|
||||
max_prompt_length: Optional[int] = None
|
||||
max_completion_length: Optional[int] = None
|
||||
|
||||
beta: float = 0.1
|
||||
disable_dropout: bool = True
|
||||
|
||||
label_pad_token_id: int = -100
|
||||
padding_value: int = None
|
||||
truncation_mode: str = "keep_end"
|
||||
generate_during_eval: bool = False
|
||||
is_encoder_decoder: Optional[bool] = None
|
||||
|
||||
model_init_kwargs: Optional[Dict] = None
|
||||
|
||||
dataset_num_proc: Optional[int] = None
|
955
trl/trainer/orpo_trainer.py
Normal file
955
trl/trainer/orpo_trainer.py
Normal file
@ -0,0 +1,955 @@
|
||||
# ORPO Authors: Jiwoo Hong, Noah Lee, and James Thorne
|
||||
# Official code: https://github.com/xfactlab/orpo
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import random
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from contextlib import nullcontext
|
||||
from copy import deepcopy
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from accelerate import PartialState
|
||||
from accelerate.utils import is_deepspeed_available
|
||||
from datasets import Dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, DataCollator, PreTrainedModel, PreTrainedTokenizerBase, Trainer
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
from transformers.trainer_utils import EvalLoopOutput
|
||||
from transformers.utils import is_torch_fx_proxy
|
||||
|
||||
from ..import_utils import is_peft_available, is_wandb_available
|
||||
from ..models import PreTrainedModelWrapper
|
||||
from .orpo_config import ORPOConfig
|
||||
from .utils import (
|
||||
DPODataCollatorWithPadding,
|
||||
disable_dropout_in_model,
|
||||
pad_to_length,
|
||||
peft_module_casting_to_bf16,
|
||||
trl_sanitze_kwargs_for_tagging,
|
||||
)
|
||||
|
||||
|
||||
if is_peft_available():
|
||||
from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
if is_deepspeed_available():
|
||||
import deepspeed
|
||||
|
||||
|
||||
class ORPOTrainer(Trainer):
|
||||
r"""
|
||||
Initialize ORPOTrainer.
|
||||
|
||||
Args:
|
||||
model (`transformers.PreTrainedModel`):
|
||||
The model to train, preferably an `AutoModelForSequenceClassification`.
|
||||
args (`ORPOConfig`):
|
||||
The ORPO config arguments to use for training.
|
||||
data_collator (`transformers.DataCollator`):
|
||||
The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
|
||||
which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
|
||||
train_dataset (`datasets.Dataset`):
|
||||
The dataset to use for training.
|
||||
eval_dataset (`datasets.Dataset`):
|
||||
The dataset to use for evaluation.
|
||||
tokenizer (`transformers.PreTrainedTokenizerBase`):
|
||||
The tokenizer to use for training. This argument is required if you want to use the default data collator.
|
||||
model_init (`Callable[[], transformers.PreTrainedModel]`):
|
||||
The model initializer to use for training. If None is specified, the default model initializer will be used.
|
||||
callbacks (`List[transformers.TrainerCallback]`):
|
||||
The callbacks to use for training.
|
||||
optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
|
||||
The optimizer and scheduler to use for training.
|
||||
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
|
||||
The function to use to preprocess the logits before computing the metrics.
|
||||
peft_config (`Dict`, defaults to `None`):
|
||||
The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
|
||||
compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*):
|
||||
The function to use to compute the metrics. Must take a `EvalPrediction` and return
|
||||
a dictionary string to metric values.
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "orpo"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
args: Optional[ORPOConfig] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
model_init: Optional[Callable[[], PreTrainedModel]] = None,
|
||||
callbacks: Optional[List[TrainerCallback]] = None,
|
||||
optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
|
||||
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
peft_config: Optional[Dict] = None,
|
||||
compute_metrics: Optional[Callable[[EvalLoopOutput], Dict]] = None,
|
||||
):
|
||||
if args.model_init_kwargs is None:
|
||||
model_init_kwargs = {}
|
||||
elif not isinstance(model, str):
|
||||
raise ValueError("You passed model_kwargs to the ORPOTrainer. But your model is already instantiated.")
|
||||
else:
|
||||
model_init_kwargs = args.model_init_kwargs
|
||||
model_init_kwargs["torch_dtype"] = (
|
||||
model_init_kwargs["torch_dtype"]
|
||||
if model_init_kwargs["torch_dtype"] in ["auto", None]
|
||||
else getattr(torch, model_init_kwargs["torch_dtype"])
|
||||
)
|
||||
|
||||
if isinstance(model, str):
|
||||
warnings.warn(
|
||||
"You passed a model_id to the ORPOTrainer. This will automatically create an "
|
||||
"`AutoModelForCausalLM` or a `PeftModel` (if you passed a `peft_config`) for you."
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
|
||||
|
||||
# Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
|
||||
# has been called in order to properly call autocast if needed.
|
||||
self._peft_has_been_casted_to_bf16 = False
|
||||
|
||||
if not is_peft_available() and peft_config is not None:
|
||||
raise ValueError(
|
||||
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
|
||||
)
|
||||
elif is_peft_available() and peft_config is not None:
|
||||
# if model is a peft model and we have a peft_config, we merge and unload it first
|
||||
if isinstance(model, PeftModel):
|
||||
model = model.merge_and_unload()
|
||||
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
||||
_support_gc_kwargs = hasattr(
|
||||
args, "gradient_checkpointing_kwargs"
|
||||
) and "gradient_checkpointing_kwargs" in list(
|
||||
inspect.signature(prepare_model_for_kbit_training).parameters
|
||||
)
|
||||
|
||||
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||
|
||||
if _support_gc_kwargs:
|
||||
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||
|
||||
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
||||
elif getattr(args, "gradient_checkpointing", False):
|
||||
# For backward compatibility with older versions of transformers
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
# get peft model with the given config
|
||||
model = get_peft_model(model, peft_config)
|
||||
if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
||||
peft_module_casting_to_bf16(model)
|
||||
# If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
|
||||
self._peft_has_been_casted_to_bf16 = True
|
||||
|
||||
# For models that use gradient_checkpointing, we need to attach a hook that enables input
|
||||
# to explicitly have `requires_grad=True`, otherwise training will either silently
|
||||
# fail or completely fail.
|
||||
elif getattr(args, "gradient_checkpointing", False):
|
||||
# For backward compatibility with older versions of transformers
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
else:
|
||||
|
||||
def make_inputs_require_grad(module, input, output):
|
||||
output.requires_grad_(True)
|
||||
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
if args.generate_during_eval and not is_wandb_available():
|
||||
raise ValueError(
|
||||
"`generate_during_eval=True` requires Weights and Biases to be installed."
|
||||
" Please install `wandb` to resolve."
|
||||
)
|
||||
|
||||
if model is not None:
|
||||
self.is_encoder_decoder = model.config.is_encoder_decoder
|
||||
elif args.is_encoder_decoder is None:
|
||||
raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
|
||||
else:
|
||||
self.is_encoder_decoder = args.is_encoder_decoder
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.decoder_start_token_id = model.config.decoder_start_token_id
|
||||
self.pad_token_id = model.config.pad_token_id
|
||||
|
||||
if tokenizer is None:
|
||||
raise ValueError("tokenizer must be specified to tokenize a ORPO dataset.")
|
||||
if args.max_length is None:
|
||||
warnings.warn(
|
||||
"`max_length` is not set in the ORPOConfig's init"
|
||||
" it will default to `512` by default, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
max_length = 512
|
||||
else:
|
||||
max_length = args.max_length
|
||||
if args.max_prompt_length is None:
|
||||
warnings.warn(
|
||||
"`max_prompt_length` is not set in the ORPOConfig's init"
|
||||
" it will default to `128` by default, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
max_prompt_length = 128
|
||||
else:
|
||||
max_prompt_length = args.max_prompt_length
|
||||
|
||||
if args.max_completion_length is None and self.is_encoder_decoder:
|
||||
warnings.warn(
|
||||
"When using an encoder decoder architecture, you should set `max_completion_length` in the ORPOConfig's init"
|
||||
" it will default to `128` by default, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
self.max_completion_length = 128
|
||||
else:
|
||||
self.max_completion_length = args.max_completion_length
|
||||
|
||||
if data_collator is None:
|
||||
data_collator = DPODataCollatorWithPadding(
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
label_pad_token_id=args.label_pad_token_id,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
)
|
||||
|
||||
if args.remove_unused_columns:
|
||||
args.remove_unused_columns = False
|
||||
# warn users
|
||||
warnings.warn(
|
||||
"When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your TrainingArguments"
|
||||
" we have set it for you, but you should do it yourself in the future.",
|
||||
UserWarning,
|
||||
)
|
||||
|
||||
self.use_dpo_data_collator = True
|
||||
else:
|
||||
self.use_dpo_data_collator = False
|
||||
|
||||
if args.disable_dropout:
|
||||
disable_dropout_in_model(model)
|
||||
|
||||
self.max_length = max_length
|
||||
self.generate_during_eval = args.generate_during_eval
|
||||
self.label_pad_token_id = args.label_pad_token_id
|
||||
self.padding_value = args.padding_value if args.padding_value is not None else tokenizer.pad_token_id
|
||||
self.max_prompt_length = max_prompt_length
|
||||
self.truncation_mode = args.truncation_mode
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.beta = args.beta
|
||||
|
||||
self._stored_metrics = defaultdict(lambda: defaultdict(list))
|
||||
|
||||
# Compute that only on the main process for faster data processing.
|
||||
# see: https://github.com/huggingface/trl/pull/1255
|
||||
with PartialState().local_main_process_first():
|
||||
# tokenize the dataset
|
||||
train_dataset = train_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
||||
if eval_dataset is not None:
|
||||
eval_dataset = eval_dataset.map(self.tokenize_row, num_proc=args.dataset_num_proc)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
model_init=model_init,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
if hasattr(self.model, "add_model_tags"):
|
||||
self.model.add_model_tags(self._tag_names)
|
||||
|
||||
if not hasattr(self, "accelerator"):
|
||||
raise AttributeError(
|
||||
"Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
|
||||
)
|
||||
|
||||
def _prepare_deepspeed(self, model: PreTrainedModelWrapper):
|
||||
# Adapted from accelerate: https://github.com/huggingface/accelerate/blob/739b135f8367becb67ffaada12fe76e3aa60fefd/src/accelerate/accelerator.py#L1473
|
||||
deepspeed_plugin = self.accelerator.state.deepspeed_plugin
|
||||
config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
|
||||
|
||||
if model is not None:
|
||||
if hasattr(model, "config"):
|
||||
hidden_size = (
|
||||
max(model.config.hidden_sizes)
|
||||
if getattr(model.config, "hidden_sizes", None)
|
||||
else getattr(model.config, "hidden_size", None)
|
||||
)
|
||||
if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
|
||||
# Note that `stage3_prefetch_bucket_size` can produce DeepSpeed messages like: `Invalidate trace cache @ step 0: expected module 1, but got module 0`
|
||||
# This is expected and is not an error, see: https://github.com/microsoft/DeepSpeed/discussions/4081
|
||||
config_kwargs.update(
|
||||
{
|
||||
"zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
|
||||
"zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
|
||||
"zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
|
||||
}
|
||||
)
|
||||
|
||||
# If ZeRO-3 is used, we shard both the active and reference model.
|
||||
# Otherwise, we assume the reference model fits in memory and is initialized on each device with ZeRO disabled (stage 0)
|
||||
if config_kwargs["zero_optimization"]["stage"] != 3:
|
||||
config_kwargs["zero_optimization"]["stage"] = 0
|
||||
model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
def build_tokenized_answer(self, prompt, answer):
|
||||
"""
|
||||
Llama tokenizer does satisfy `enc(a + b) = enc(a) + enc(b)`.
|
||||
It does ensure `enc(a + b) = enc(a) + enc(a + b)[len(enc(a)):]`.
|
||||
Reference:
|
||||
https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
||||
"""
|
||||
|
||||
full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
|
||||
prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
|
||||
|
||||
answer_input_ids = full_tokenized["input_ids"][len(prompt_input_ids) :]
|
||||
answer_attention_mask = full_tokenized["attention_mask"][len(prompt_input_ids) :]
|
||||
|
||||
# Concat tokens to form `enc(a) + enc(a + b)[len(enc(a)):]`
|
||||
full_concat_input_ids = np.concatenate([prompt_input_ids, answer_input_ids])
|
||||
|
||||
# Prepare input tokens for token by token comparison
|
||||
full_input_ids = np.array(full_tokenized["input_ids"])
|
||||
|
||||
if len(full_input_ids) != len(full_concat_input_ids):
|
||||
raise ValueError("Prompt input ids and answer input ids should have the same length.")
|
||||
|
||||
# On some tokenizers, like Llama-2 tokenizer, there are occasions where tokens
|
||||
# can be merged together when tokenizing prompt+answer. This could result
|
||||
# on the last token from the prompt being different when tokenized on its own
|
||||
# vs when done as prompt+answer.
|
||||
response_token_ids_start_idx = len(prompt_input_ids)
|
||||
|
||||
# If tokenized prompt is different than both prompt+answer, then it means the
|
||||
# last token has changed due to merging.
|
||||
if prompt_input_ids != full_tokenized["input_ids"][:response_token_ids_start_idx]:
|
||||
response_token_ids_start_idx -= 1
|
||||
|
||||
prompt_input_ids = full_tokenized["input_ids"][:response_token_ids_start_idx]
|
||||
prompt_attention_mask = full_tokenized["attention_mask"][:response_token_ids_start_idx]
|
||||
|
||||
if len(prompt_input_ids) != len(prompt_attention_mask):
|
||||
raise ValueError("Prompt input ids and attention mask should have the same length.")
|
||||
|
||||
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
|
||||
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
|
||||
|
||||
return dict(
|
||||
prompt_input_ids=prompt_input_ids,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
input_ids=answer_input_ids,
|
||||
attention_mask=answer_attention_mask,
|
||||
)
|
||||
|
||||
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
|
||||
"""Tokenize a single row from a ORPO specific dataset.
|
||||
|
||||
At this stage, we don't convert to PyTorch tensors yet; we just handle the truncation
|
||||
in case the prompt + chosen or prompt + rejected responses is/are too long. First
|
||||
we truncate the prompt; if we're still too long, we truncate the chosen/rejected.
|
||||
|
||||
We also create the labels for the chosen/rejected responses, which are of length equal to
|
||||
the sum of the length of the prompt and the chosen/rejected response, with
|
||||
label_pad_token_id for the prompt tokens.
|
||||
"""
|
||||
batch = {}
|
||||
prompt = feature["prompt"]
|
||||
chosen = feature["chosen"]
|
||||
rejected = feature["rejected"]
|
||||
|
||||
if not self.is_encoder_decoder:
|
||||
# Check issues below for more details
|
||||
# 1. https://github.com/huggingface/trl/issues/907
|
||||
# 2. https://github.com/EleutherAI/lm-evaluation-harness/pull/531#issuecomment-1595586257
|
||||
# 3. https://github.com/LianjiaTech/BELLE/issues/337
|
||||
|
||||
if not isinstance(prompt, str):
|
||||
raise ValueError(f"prompt should be an str but got {type(prompt)}")
|
||||
prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
|
||||
prompt_tokens = {f"prompt_{k}": v for k, v in prompt_tokens.items()}
|
||||
|
||||
if not isinstance(chosen, str):
|
||||
raise ValueError(f"chosen should be an str but got {type(chosen)}")
|
||||
chosen_tokens = self.build_tokenized_answer(prompt, chosen)
|
||||
|
||||
if not isinstance(rejected, str):
|
||||
raise ValueError(f"rejected should be an str but got {type(rejected)}")
|
||||
rejected_tokens = self.build_tokenized_answer(prompt, rejected)
|
||||
|
||||
# Last prompt token might get merged by tokenizer and
|
||||
# it should not be included for generation if that happens
|
||||
prompt_len_input_ids = len(prompt_tokens["prompt_input_ids"])
|
||||
|
||||
chosen_prompt_len_input_ids = len(chosen_tokens["prompt_input_ids"])
|
||||
rejected_prompt_len_input_ids = len(rejected_tokens["prompt_input_ids"])
|
||||
prompt_len_input_ids = min(chosen_prompt_len_input_ids, rejected_prompt_len_input_ids)
|
||||
|
||||
for k, v in prompt_tokens.items():
|
||||
prompt_tokens[k] = v[:prompt_len_input_ids]
|
||||
|
||||
# Make sure prompts only have one different token at most an
|
||||
# and length only differs by 1 at most
|
||||
num_diff_tokens = sum(
|
||||
[a != b for a, b in zip(chosen_tokens["prompt_input_ids"], rejected_tokens["prompt_input_ids"])]
|
||||
)
|
||||
num_diff_len = abs(chosen_prompt_len_input_ids - rejected_prompt_len_input_ids)
|
||||
if num_diff_tokens > 1 or num_diff_len > 1:
|
||||
raise ValueError(
|
||||
"Chosen and rejected prompt_input_ids might only differ on the "
|
||||
"last token due to tokenizer merge ops."
|
||||
)
|
||||
|
||||
# add BOS token to head of prompt
|
||||
prompt_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + prompt_tokens["prompt_input_ids"]
|
||||
chosen_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + chosen_tokens["prompt_input_ids"]
|
||||
rejected_tokens["prompt_input_ids"] = [self.tokenizer.bos_token_id] + rejected_tokens["prompt_input_ids"]
|
||||
|
||||
prompt_tokens["prompt_attention_mask"] = [1] + prompt_tokens["prompt_attention_mask"]
|
||||
chosen_tokens["prompt_attention_mask"] = [1] + chosen_tokens["prompt_attention_mask"]
|
||||
rejected_tokens["prompt_attention_mask"] = [1] + rejected_tokens["prompt_attention_mask"]
|
||||
|
||||
# add EOS token to end of answer
|
||||
chosen_tokens["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
chosen_tokens["attention_mask"].append(1)
|
||||
|
||||
rejected_tokens["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
rejected_tokens["attention_mask"].append(1)
|
||||
|
||||
longer_response_length = max(len(chosen_tokens["input_ids"]), len(rejected_tokens["input_ids"]))
|
||||
|
||||
# if combined sequence is too long, truncate the prompt
|
||||
for answer_tokens in [chosen_tokens, rejected_tokens, prompt_tokens]:
|
||||
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
||||
if self.truncation_mode == "keep_start":
|
||||
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
||||
answer_tokens[k] = answer_tokens[k][: self.max_prompt_length]
|
||||
elif self.truncation_mode == "keep_end":
|
||||
for k in ["prompt_input_ids", "prompt_attention_mask"]:
|
||||
answer_tokens[k] = answer_tokens[k][-self.max_prompt_length :]
|
||||
else:
|
||||
raise ValueError(f"Unknown truncation mode: {self.truncation_mode}")
|
||||
|
||||
# if that's still too long, truncate the response
|
||||
for answer_tokens in [chosen_tokens, rejected_tokens]:
|
||||
if len(answer_tokens["prompt_input_ids"]) + longer_response_length > self.max_length:
|
||||
for k in ["input_ids", "attention_mask"]:
|
||||
answer_tokens[k] = answer_tokens[k][: self.max_length - self.max_prompt_length]
|
||||
|
||||
# Create labels
|
||||
chosen_sequence_tokens = {
|
||||
k: chosen_tokens[f"prompt_{k}"] + chosen_tokens[k] for k in ["input_ids", "attention_mask"]
|
||||
}
|
||||
rejected_sequence_tokens = {
|
||||
k: rejected_tokens[f"prompt_{k}"] + rejected_tokens[k] for k in ["input_ids", "attention_mask"]
|
||||
}
|
||||
chosen_sequence_tokens["labels"] = chosen_sequence_tokens["input_ids"][:]
|
||||
chosen_sequence_tokens["labels"][: len(chosen_tokens["prompt_input_ids"])] = [
|
||||
self.label_pad_token_id
|
||||
] * len(chosen_tokens["prompt_input_ids"])
|
||||
rejected_sequence_tokens["labels"] = rejected_sequence_tokens["input_ids"][:]
|
||||
rejected_sequence_tokens["labels"][: len(rejected_tokens["prompt_input_ids"])] = [
|
||||
self.label_pad_token_id
|
||||
] * len(rejected_tokens["prompt_input_ids"])
|
||||
|
||||
for k, toks in {
|
||||
"chosen_": chosen_sequence_tokens,
|
||||
"rejected_": rejected_sequence_tokens,
|
||||
"": prompt_tokens,
|
||||
}.items():
|
||||
for type_key, tokens in toks.items():
|
||||
if type_key == "token_type_ids":
|
||||
continue
|
||||
batch[f"{k}{type_key}"] = tokens
|
||||
|
||||
else:
|
||||
chosen_tokens = self.tokenizer(
|
||||
chosen, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
||||
)
|
||||
rejected_tokens = self.tokenizer(
|
||||
rejected, truncation=True, max_length=self.max_completion_length, add_special_tokens=True
|
||||
)
|
||||
prompt_tokens = self.tokenizer(
|
||||
prompt, truncation=True, max_length=self.max_prompt_length, add_special_tokens=True
|
||||
)
|
||||
|
||||
batch["chosen_labels"] = chosen_tokens["input_ids"]
|
||||
batch["rejected_labels"] = rejected_tokens["input_ids"]
|
||||
batch["prompt_input_ids"] = prompt_tokens["input_ids"]
|
||||
batch["prompt_attention_mask"] = prompt_tokens["attention_mask"]
|
||||
|
||||
if model is not None and hasattr(model, "prepare_decoder_input_ids_from_labels"):
|
||||
batch["rejected_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
||||
labels=torch.tensor(batch["rejected_labels"])
|
||||
)
|
||||
batch["chosen_decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
|
||||
labels=torch.tensor(batch["chosen_labels"])
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
@staticmethod
|
||||
def concatenated_inputs(
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
is_encoder_decoder: bool = False,
|
||||
label_pad_token_id: int = -100,
|
||||
padding_value: int = 0,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Dict[str, torch.LongTensor]:
|
||||
"""Concatenate the chosen and rejected inputs into a single tensor.
|
||||
|
||||
Args:
|
||||
batch: A batch of data. Must contain the keys 'chosen_input_ids' and 'rejected_input_ids', which are tensors of shape (batch_size, sequence_length).
|
||||
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
||||
label_pad_token_id: The label pad token id.
|
||||
padding_value: The padding value to use for the concatenated inputs_ids.
|
||||
device: The device for the concatenated inputs.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the concatenated inputs under the key 'concatenated_input_ids'.
|
||||
"""
|
||||
concatenated_batch = {}
|
||||
|
||||
if is_encoder_decoder:
|
||||
max_length = max(batch["chosen_labels"].shape[1], batch["rejected_labels"].shape[1])
|
||||
else:
|
||||
max_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
||||
|
||||
for k in batch:
|
||||
if k.startswith("chosen") and isinstance(batch[k], torch.Tensor):
|
||||
if "labels" in k or is_encoder_decoder:
|
||||
pad_value = label_pad_token_id
|
||||
elif k.endswith("_input_ids"):
|
||||
pad_value = padding_value
|
||||
elif k.endswith("_attention_mask"):
|
||||
pad_value = 0
|
||||
concatenated_key = k.replace("chosen", "concatenated")
|
||||
concatenated_batch[concatenated_key] = pad_to_length(batch[k], max_length, pad_value=pad_value)
|
||||
for k in batch:
|
||||
if k.startswith("rejected") and isinstance(batch[k], torch.Tensor):
|
||||
if "labels" in k or is_encoder_decoder:
|
||||
pad_value = label_pad_token_id
|
||||
elif k.endswith("_input_ids"):
|
||||
pad_value = padding_value
|
||||
elif k.endswith("_attention_mask"):
|
||||
pad_value = 0
|
||||
concatenated_key = k.replace("rejected", "concatenated")
|
||||
concatenated_batch[concatenated_key] = torch.cat(
|
||||
(
|
||||
concatenated_batch[concatenated_key],
|
||||
pad_to_length(batch[k], max_length, pad_value=pad_value),
|
||||
),
|
||||
dim=0,
|
||||
).to(device=device)
|
||||
|
||||
if is_encoder_decoder:
|
||||
concatenated_batch["concatenated_input_ids"] = batch["prompt_input_ids"].repeat(2, 1).to(device=device)
|
||||
concatenated_batch["concatenated_attention_mask"] = (
|
||||
batch["prompt_attention_mask"].repeat(2, 1).to(device=device)
|
||||
)
|
||||
|
||||
return concatenated_batch
|
||||
|
||||
def odds_ratio_loss(
|
||||
self,
|
||||
policy_chosen_logps: torch.FloatTensor,
|
||||
policy_rejected_logps: torch.FloatTensor,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Compute ORPO's odds ratio (OR) loss for a batch of policy and reference model log probabilities.
|
||||
|
||||
Args:
|
||||
policy_chosen_logps: Log probabilities of the policy model for the chosen responses. Shape: (batch_size,)
|
||||
policy_rejected_logps: Log probabilities of the policy model for the rejected responses. Shape: (batch_size,)
|
||||
|
||||
Returns:
|
||||
A tuple of three tensors: (losses, chosen_rewards, rejected_rewards).
|
||||
The losses tensor contains the ORPO loss for each example in the batch.
|
||||
The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
|
||||
The log odds ratio of the chosen responses over the rejected responses ratio for logging purposes.
|
||||
The `log(sigmoid(log_odds_chosen))` for logging purposes.
|
||||
"""
|
||||
|
||||
# Derived from Eqs. (4) and (7) from https://arxiv.org/abs/2403.07691 by using log identities and exp(log(P(y|x)) = P(y|x)
|
||||
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
|
||||
torch.log1p(-torch.exp(policy_chosen_logps)) - torch.log1p(-torch.exp(policy_rejected_logps))
|
||||
)
|
||||
sig_ratio = F.sigmoid(log_odds)
|
||||
ratio = torch.log(sig_ratio)
|
||||
losses = self.beta * ratio
|
||||
|
||||
chosen_rewards = self.beta * (policy_chosen_logps.to(self.accelerator.device)).detach()
|
||||
rejected_rewards = self.beta * (policy_rejected_logps.to(self.accelerator.device)).detach()
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards, torch.mean(ratio).item(), torch.mean(log_odds).item()
|
||||
|
||||
@staticmethod
|
||||
def get_batch_logps(
|
||||
logits: torch.FloatTensor,
|
||||
labels: torch.LongTensor,
|
||||
average_log_prob: bool = False,
|
||||
label_pad_token_id: int = -100,
|
||||
is_encoder_decoder: bool = False,
|
||||
) -> torch.FloatTensor:
|
||||
"""Compute the log probabilities of the given labels under the given logits.
|
||||
|
||||
Args:
|
||||
logits: Logits of the model (unnormalized). Shape: (batch_size, sequence_length, vocab_size)
|
||||
labels: Labels for which to compute the log probabilities. Label tokens with a value of label_pad_token_id are ignored. Shape: (batch_size, sequence_length)
|
||||
average_log_prob: If True, return the average log probability per (non-masked) token. Otherwise, return the sum of the log probabilities of the (non-masked) tokens.
|
||||
label_pad_token_id: The label pad token id.
|
||||
is_encoder_decoder: Whether the model is an encoder-decoder model.
|
||||
|
||||
Returns:
|
||||
A tensor of shape (batch_size,) containing the average/sum log probabilities of the given labels under the given logits.
|
||||
"""
|
||||
if logits.shape[:-1] != labels.shape:
|
||||
raise ValueError("Logits (batch and sequence length dim) and labels must have the same shape.")
|
||||
|
||||
if not is_encoder_decoder:
|
||||
labels = labels[:, 1:].clone()
|
||||
logits = logits[:, :-1, :]
|
||||
loss_mask = labels != label_pad_token_id
|
||||
|
||||
# dummy token; we'll ignore the losses on these tokens later
|
||||
labels[labels == label_pad_token_id] = 0
|
||||
|
||||
per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
|
||||
|
||||
if average_log_prob:
|
||||
return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
|
||||
else:
|
||||
return (per_token_logps * loss_mask).sum(-1)
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
"""Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
|
||||
|
||||
We do this to avoid doing two forward passes, because it's faster for FSDP.
|
||||
"""
|
||||
concatenated_batch = self.concatenated_inputs(
|
||||
batch,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
padding_value=self.padding_value,
|
||||
device=self.accelerator.device,
|
||||
)
|
||||
len_chosen = batch["chosen_labels"].shape[0]
|
||||
|
||||
model_kwargs = (
|
||||
{
|
||||
"decoder_input_ids": self._shift_right(concatenated_batch["concatenated_labels"]),
|
||||
}
|
||||
if self.is_encoder_decoder
|
||||
else {}
|
||||
)
|
||||
|
||||
outputs = model(
|
||||
concatenated_batch["concatenated_input_ids"],
|
||||
attention_mask=concatenated_batch["concatenated_attention_mask"],
|
||||
use_cache=False,
|
||||
**model_kwargs,
|
||||
)
|
||||
all_logits = outputs.logits
|
||||
|
||||
def cross_entropy_loss(logits, labels):
|
||||
if not self.is_encoder_decoder:
|
||||
# Shift so that tokens < n predict n
|
||||
logits = logits[..., :-1, :].contiguous()
|
||||
labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
logits = logits.view(-1, logits.shape[-1])
|
||||
labels = labels.view(-1)
|
||||
# Enable model parallelism
|
||||
labels = labels.to(logits.device)
|
||||
loss = loss_fct(logits, labels)
|
||||
return loss
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
labels = concatenated_batch["concatenated_labels"].clone()
|
||||
else:
|
||||
labels = concatenated_batch["concatenated_input_ids"].clone()
|
||||
|
||||
chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
|
||||
|
||||
all_logps = self.get_batch_logps(
|
||||
all_logits,
|
||||
concatenated_batch["concatenated_labels"],
|
||||
average_log_prob=True,
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
|
||||
chosen_logps = all_logps[:len_chosen]
|
||||
rejected_logps = all_logps[len_chosen:]
|
||||
|
||||
chosen_logits = all_logits[:len_chosen]
|
||||
rejected_logits = all_logits[len_chosen:]
|
||||
|
||||
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss)
|
||||
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model,
|
||||
batch: Dict[str, Union[List, torch.LongTensor]],
|
||||
train_eval: Literal["train", "eval"] = "train",
|
||||
):
|
||||
"""Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
|
||||
metrics = {}
|
||||
|
||||
(
|
||||
policy_chosen_logps,
|
||||
policy_rejected_logps,
|
||||
policy_chosen_logits,
|
||||
policy_rejected_logits,
|
||||
policy_nll_loss,
|
||||
) = self.concatenated_forward(model, batch)
|
||||
|
||||
losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = self.odds_ratio_loss(
|
||||
policy_chosen_logps, policy_rejected_logps
|
||||
)
|
||||
# full ORPO loss
|
||||
loss = policy_nll_loss - losses.mean()
|
||||
|
||||
reward_accuracies = (chosen_rewards > rejected_rewards).float()
|
||||
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
metrics[f"{prefix}rewards/chosen"] = chosen_rewards.mean().cpu()
|
||||
metrics[f"{prefix}rewards/rejected"] = rejected_rewards.mean().cpu()
|
||||
metrics[f"{prefix}rewards/accuracies"] = reward_accuracies.mean().cpu()
|
||||
metrics[f"{prefix}rewards/margins"] = (chosen_rewards - rejected_rewards).mean().cpu()
|
||||
metrics[f"{prefix}logps/rejected"] = policy_rejected_logps.detach().mean().cpu()
|
||||
metrics[f"{prefix}logps/chosen"] = policy_chosen_logps.detach().mean().cpu()
|
||||
metrics[f"{prefix}logits/rejected"] = policy_rejected_logits.detach().mean().cpu()
|
||||
metrics[f"{prefix}logits/chosen"] = policy_chosen_logits.detach().mean().cpu()
|
||||
metrics[f"{prefix}nll_loss"] = policy_nll_loss.detach().mean().cpu()
|
||||
metrics[f"{prefix}log_odds_ratio"] = log_odds_ratio
|
||||
metrics[f"{prefix}log_odds_chosen"] = log_odds_chosen
|
||||
|
||||
return loss, metrics
|
||||
|
||||
def compute_loss(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
return_outputs=False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, Dict[str, torch.Tensor]]]:
|
||||
if not self.use_dpo_data_collator:
|
||||
warnings.warn(
|
||||
"compute_loss is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
|
||||
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
|
||||
)
|
||||
|
||||
compute_loss_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
|
||||
|
||||
with compute_loss_context_manager():
|
||||
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="train")
|
||||
|
||||
# force log the metrics
|
||||
self.store_metrics(metrics, train_eval="train")
|
||||
|
||||
if return_outputs:
|
||||
return (loss, metrics)
|
||||
return loss
|
||||
|
||||
def get_batch_samples(self, model, batch: Dict[str, torch.LongTensor]) -> Tuple[str, str]:
|
||||
"""Generate samples from the model and reference model for the given batch of inputs."""
|
||||
|
||||
# If one uses `generate_during_eval` with peft + bf16, we need to explicitly call generate with
|
||||
# the torch cuda amp context manager as some hidden states are silently casted to full precision.
|
||||
generate_context_manager = nullcontext if not self._peft_has_been_casted_to_bf16 else torch.cuda.amp.autocast
|
||||
|
||||
with generate_context_manager():
|
||||
policy_output = model.generate(
|
||||
input_ids=batch["prompt_input_ids"],
|
||||
attention_mask=batch["prompt_attention_mask"],
|
||||
max_length=self.max_length,
|
||||
do_sample=True,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
|
||||
policy_output = pad_to_length(policy_output, self.max_length, self.tokenizer.pad_token_id)
|
||||
policy_output_decoded = self.tokenizer.batch_decode(policy_output, skip_special_tokens=True)
|
||||
|
||||
return policy_output_decoded
|
||||
|
||||
def prediction_step(
|
||||
self,
|
||||
model: Union[PreTrainedModel, nn.Module],
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
):
|
||||
if not self.use_dpo_data_collator:
|
||||
warnings.warn(
|
||||
"prediction_step is only implemented for DPODataCollatorWithPadding, and you passed a datacollator that is different than "
|
||||
"DPODataCollatorWithPadding - you might see unexpected behavior. Alternatively, you can implement your own prediction_step method if you are using a custom data collator"
|
||||
)
|
||||
if ignore_keys is None:
|
||||
if hasattr(model, "config"):
|
||||
ignore_keys = getattr(model.config, "keys_to_ignore_at_inference", [])
|
||||
else:
|
||||
ignore_keys = []
|
||||
|
||||
prediction_context_manager = torch.cuda.amp.autocast if self._peft_has_been_casted_to_bf16 else nullcontext
|
||||
|
||||
with torch.no_grad(), prediction_context_manager():
|
||||
loss, metrics = self.get_batch_loss_metrics(model, inputs, train_eval="eval")
|
||||
|
||||
# force log the metrics
|
||||
self.store_metrics(metrics, train_eval="eval")
|
||||
|
||||
if prediction_loss_only:
|
||||
return (loss.detach(), None, None)
|
||||
|
||||
# logits for the chosen and rejected samples from model
|
||||
logits_dict = {
|
||||
"eval_logits/chosen": metrics["eval_logits/chosen"],
|
||||
"eval_logits/rejected": metrics["eval_logits/rejected"],
|
||||
}
|
||||
logits = tuple(v.unsqueeze(dim=0) for k, v in logits_dict.items() if k not in ignore_keys)
|
||||
logits = torch.stack(logits).mean(axis=1).to(self.accelerator.device)
|
||||
labels = torch.zeros(logits.shape[0], device=self.accelerator.device)
|
||||
|
||||
return (loss.detach(), logits, labels)
|
||||
|
||||
def store_metrics(self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train") -> None:
|
||||
for key, value in metrics.items():
|
||||
self._stored_metrics[train_eval][key].append(value)
|
||||
|
||||
def evaluation_loop(
|
||||
self,
|
||||
dataloader: DataLoader,
|
||||
description: str,
|
||||
prediction_loss_only: Optional[bool] = None,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
metric_key_prefix: str = "eval",
|
||||
) -> EvalLoopOutput:
|
||||
"""
|
||||
Overriding built-in evaluation loop to store metrics for each batch.
|
||||
Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.
|
||||
|
||||
Works both with or without labels.
|
||||
"""
|
||||
|
||||
# Sample and save to game log if requested (for one batch to save time)
|
||||
if self.generate_during_eval:
|
||||
# Generate random indices within the range of the total number of samples
|
||||
num_samples = len(dataloader.dataset)
|
||||
random_indices = random.sample(range(num_samples), k=self.args.eval_batch_size)
|
||||
|
||||
# Use dataloader.dataset.select to get the random batch without iterating over the DataLoader
|
||||
random_batch_dataset = dataloader.dataset.select(random_indices)
|
||||
random_batch = self.data_collator(random_batch_dataset)
|
||||
random_batch = self._prepare_inputs(random_batch)
|
||||
|
||||
policy_output_decoded = self.get_batch_samples(self.model, random_batch)
|
||||
|
||||
self.log(
|
||||
{
|
||||
"game_log": wandb.Table(
|
||||
columns=["Prompt", "Policy"],
|
||||
rows=[
|
||||
[prompt, pol[len(prompt) :]]
|
||||
for prompt, pol in zip(random_batch["prompt"], policy_output_decoded)
|
||||
],
|
||||
)
|
||||
}
|
||||
)
|
||||
self.state.log_history.pop()
|
||||
|
||||
# Base evaluation
|
||||
initial_output = super().evaluation_loop(
|
||||
dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix
|
||||
)
|
||||
|
||||
return initial_output
|
||||
|
||||
def log(self, logs: Dict[str, float]) -> None:
|
||||
"""
|
||||
Log `logs` on the various objects watching training, including stored metrics.
|
||||
|
||||
Args:
|
||||
logs (`Dict[str, float]`):
|
||||
The values to log.
|
||||
"""
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
return super().log(logs)
|
||||
|
||||
def _shift_right(self, input_ids):
|
||||
if self.decoder_start_token_id is None:
|
||||
raise ValueError(
|
||||
"model.config.decoder_start_token_id has to be defined. It is usually set to the pad_token_id."
|
||||
)
|
||||
|
||||
# shift inputs to the right
|
||||
if is_torch_fx_proxy(input_ids):
|
||||
# Item assignment is not supported natively for proxies.
|
||||
shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), self.decoder_start_token_id)
|
||||
shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
|
||||
else:
|
||||
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
||||
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
||||
shifted_input_ids[..., 0] = self.decoder_start_token_id
|
||||
|
||||
if self.pad_token_id is None:
|
||||
raise ValueError("model.config.pad_token_id has to be defined.")
|
||||
# replace possible -100 values in labels by `pad_token_id`
|
||||
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
||||
|
||||
return shifted_input_ids
|
||||
|
||||
@wraps(Trainer.push_to_hub)
|
||||
def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
|
||||
"""
|
||||
Overwrite the `push_to_hub` method in order to force-add the tag "orpo" when pushing the
|
||||
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||
"""
|
||||
kwargs = trl_sanitze_kwargs_for_tagging(model=self.model, tag_names=self._tag_names, kwargs=kwargs)
|
||||
|
||||
return super().push_to_hub(commit_message=commit_message, blocking=blocking, **kwargs)
|
@ -53,7 +53,12 @@ from ..core import (
|
||||
stats_to_np,
|
||||
)
|
||||
from ..import_utils import is_npu_available, is_torch_greater_2_0, is_xpu_available
|
||||
from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper, create_reference_model
|
||||
from ..models import (
|
||||
SUPPORTED_ARCHITECTURES,
|
||||
PreTrainedModelWrapper,
|
||||
create_reference_model,
|
||||
unwrap_model_for_generation,
|
||||
)
|
||||
from . import AdaptiveKLController, BaseTrainer, FixedKLController, PPOConfig, RunningMoments
|
||||
|
||||
|
||||
@ -470,15 +475,14 @@ class PPOTrainer(BaseTrainer):
|
||||
**generation_kwargs,
|
||||
)
|
||||
if generate_ref_response:
|
||||
with self.optional_peft_ctx():
|
||||
ref_response = self._generate_batched(
|
||||
ref_model,
|
||||
query_tensor,
|
||||
length_sampler=length_sampler,
|
||||
batch_size=batch_size,
|
||||
return_prompt=return_prompt,
|
||||
**generation_kwargs,
|
||||
)
|
||||
ref_response = self._generate_batched(
|
||||
ref_model,
|
||||
query_tensor,
|
||||
length_sampler=length_sampler,
|
||||
batch_size=batch_size,
|
||||
return_prompt=return_prompt,
|
||||
**generation_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
if len(query_tensor.shape) == 2:
|
||||
@ -488,12 +492,17 @@ class PPOTrainer(BaseTrainer):
|
||||
|
||||
if length_sampler is not None:
|
||||
generation_kwargs["max_new_tokens"] = length_sampler()
|
||||
response = self.accelerator.unwrap_model(self.model).generate(
|
||||
input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
|
||||
)
|
||||
|
||||
with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
|
||||
response = unwrapped_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
|
||||
|
||||
if generate_ref_response:
|
||||
with self.optional_peft_ctx():
|
||||
ref_response = ref_model.generate(input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs)
|
||||
with unwrap_model_for_generation(
|
||||
ref_model, self.accelerator, is_peft_model=self.is_peft_model
|
||||
) as unwrapped_model:
|
||||
ref_response = unwrapped_model.generate(
|
||||
input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
|
||||
)
|
||||
|
||||
if not return_prompt and not self.is_encoder_decoder:
|
||||
response = response[:, query_tensor.shape[0] :]
|
||||
@ -543,7 +552,8 @@ class PPOTrainer(BaseTrainer):
|
||||
return_tensors="pt",
|
||||
).to(self.current_device)
|
||||
|
||||
generations = self.accelerator.unwrap_model(model).generate(**padded_inputs, **generation_kwargs)
|
||||
with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model:
|
||||
generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)
|
||||
|
||||
for generation, mask in zip(generations, padded_inputs["attention_mask"]):
|
||||
if not self.is_encoder_decoder:
|
||||
|
@ -137,7 +137,7 @@ class RewardTrainer(Trainer):
|
||||
inspect.signature(prepare_model_for_kbit_training).parameters
|
||||
)
|
||||
|
||||
preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||
prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
|
||||
|
||||
if not _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
||||
warnings.warn(
|
||||
@ -145,9 +145,9 @@ class RewardTrainer(Trainer):
|
||||
"please update to the latest version of peft to use `gradient_checkpointing_kwargs`."
|
||||
)
|
||||
elif _supports_gc_kwargs and args.gradient_checkpointing_kwargs is not None:
|
||||
preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||
prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
|
||||
|
||||
model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
|
||||
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
||||
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
@ -196,17 +196,17 @@ class RewardTrainer(Trainer):
|
||||
else:
|
||||
self.use_reward_data_collator = False
|
||||
super().__init__(
|
||||
model,
|
||||
args,
|
||||
data_collator,
|
||||
train_dataset,
|
||||
eval_dataset,
|
||||
tokenizer,
|
||||
model_init,
|
||||
compute_metrics,
|
||||
callbacks,
|
||||
optimizers,
|
||||
preprocess_logits_for_metrics,
|
||||
model=model,
|
||||
args=args,
|
||||
data_collator=data_collator,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
tokenizer=tokenizer,
|
||||
model_init=model_init,
|
||||
compute_metrics=compute_metrics,
|
||||
callbacks=callbacks,
|
||||
optimizers=optimizers,
|
||||
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
||||
)
|
||||
|
||||
# Add tags for models that have been loaded with the correct transformers version
|
||||
|
@ -17,6 +17,7 @@ import warnings
|
||||
from functools import wraps
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from accelerate.state import PartialState
|
||||
@ -42,6 +43,7 @@ from ..import_utils import is_peft_available
|
||||
from .utils import (
|
||||
ConstantLengthDataset,
|
||||
DataCollatorForCompletionOnlyLM,
|
||||
RichProgressCallback,
|
||||
neftune_post_forward_hook,
|
||||
peft_module_casting_to_bf16,
|
||||
trl_sanitze_kwargs_for_tagging,
|
||||
@ -116,6 +118,8 @@ class SFTTrainer(Trainer):
|
||||
Dict of Optional kwargs to pass when instantiating the model from a string
|
||||
dataset_kwargs: (`Optional[Dict]`, *optional*):
|
||||
Dict of Optional kwargs to pass when creating packed or non-packed datasets
|
||||
eval_packing: (`Optional[bool]`, *optional*):
|
||||
Whether to pack the eval dataset as well. Defaults to `packing` if `None` is passed.
|
||||
"""
|
||||
|
||||
_tag_names = ["trl", "sft"]
|
||||
@ -124,7 +128,7 @@ class SFTTrainer(Trainer):
|
||||
self,
|
||||
model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
|
||||
args: Optional[TrainingArguments] = None,
|
||||
data_collator: Optional[DataCollator] = None,
|
||||
data_collator: Optional[DataCollator] = None, # type: ignore
|
||||
train_dataset: Optional[Dataset] = None,
|
||||
eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
||||
@ -146,6 +150,7 @@ class SFTTrainer(Trainer):
|
||||
neftune_noise_alpha: Optional[float] = None,
|
||||
model_init_kwargs: Optional[Dict] = None,
|
||||
dataset_kwargs: Optional[Dict] = None,
|
||||
eval_packing: Optional[bool] = None,
|
||||
):
|
||||
if model_init_kwargs is None:
|
||||
model_init_kwargs = {}
|
||||
@ -183,15 +188,26 @@ class SFTTrainer(Trainer):
|
||||
inspect.signature(prepare_model_for_kbit_training).parameters
|
||||
)
|
||||
gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
|
||||
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
|
||||
preprare_model_kwargs = {
|
||||
is_sharded_qlora = False
|
||||
# Below is to support QLoRA + FSDP / DS-Zero3 - one should never call
|
||||
# peft_module_casting_to_bf16 or prepare_model_for_kbit_training when doing
|
||||
# QLoRA + FSDP / DS-Zero3
|
||||
if getattr(model, "is_loaded_in_4bit", False):
|
||||
for _, param in model.named_parameters():
|
||||
if param.__class__.__name__ == "Params4bit":
|
||||
is_sharded_qlora = param.data.device.type == "cpu"
|
||||
break
|
||||
if getattr(model, "is_loaded_in_8bit", False) or (
|
||||
getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora
|
||||
):
|
||||
prepare_model_kwargs = {
|
||||
"use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
|
||||
}
|
||||
|
||||
if _support_gc_kwargs:
|
||||
preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
|
||||
prepare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
|
||||
|
||||
model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
|
||||
model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
|
||||
|
||||
if args is not None:
|
||||
args = dataclasses.replace(args, gradient_checkpointing=False)
|
||||
@ -210,7 +226,12 @@ class SFTTrainer(Trainer):
|
||||
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
|
||||
|
||||
model = get_peft_model(model, peft_config)
|
||||
if args is not None and args.bf16 and getattr(model, "is_loaded_in_4bit", False):
|
||||
if (
|
||||
args is not None
|
||||
and args.bf16
|
||||
and getattr(model, "is_loaded_in_4bit", False)
|
||||
and not is_sharded_qlora
|
||||
):
|
||||
peft_module_casting_to_bf16(model)
|
||||
|
||||
if tokenizer is None:
|
||||
@ -274,11 +295,14 @@ class SFTTrainer(Trainer):
|
||||
if eval_dataset is not None:
|
||||
_multiple = isinstance(eval_dataset, dict)
|
||||
_eval_datasets = eval_dataset if _multiple else {"singleton": eval_dataset}
|
||||
|
||||
eval_packing = packing if eval_packing is None else eval_packing
|
||||
|
||||
for _eval_dataset_name, _eval_dataset in _eval_datasets.items():
|
||||
_eval_datasets[_eval_dataset_name] = self._prepare_dataset(
|
||||
_eval_dataset,
|
||||
tokenizer,
|
||||
packing,
|
||||
eval_packing,
|
||||
dataset_text_field,
|
||||
max_seq_length,
|
||||
formatting_func,
|
||||
@ -322,6 +346,12 @@ class SFTTrainer(Trainer):
|
||||
elif self.args.max_steps == -1 and packing:
|
||||
self.train_dataset.infinite = False
|
||||
|
||||
if any(isinstance(callback, RichProgressCallback) for callback in self.callback_handler.callbacks):
|
||||
for callback in self.callback_handler.callbacks:
|
||||
# Remove the PrinterCallback to avoid duplicated prints in case we passed a `RichProgressCallback`
|
||||
if callback.__class__.__name__ == "PrinterCallback":
|
||||
self.callback_handler.pop_callback(callback)
|
||||
|
||||
@wraps(Trainer.train)
|
||||
def train(self, *args, **kwargs):
|
||||
# Activate neftune right before training.
|
||||
@ -367,12 +397,27 @@ class SFTTrainer(Trainer):
|
||||
remove_unused_columns=True,
|
||||
append_concat_token=True,
|
||||
add_special_tokens=True,
|
||||
skip_prepare_dataset=False,
|
||||
):
|
||||
if dataset is None:
|
||||
raise ValueError("The dataset should not be None")
|
||||
|
||||
if skip_prepare_dataset:
|
||||
return dataset
|
||||
|
||||
# If the dataset is already preprocessed (tokenized), return as-is. Only works if dataset is
|
||||
# a datasets.Dataset or datasets.IterableDataset -- not for torch Dataset
|
||||
column_names = (
|
||||
dataset.column_names if isinstance(dataset, (datasets.Dataset, datasets.IterableDataset)) else None
|
||||
)
|
||||
if column_names and "input_ids" in column_names:
|
||||
return dataset
|
||||
|
||||
# check if torch dataset / dataloader and do nothing
|
||||
if isinstance(dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)):
|
||||
# see https://github.com/huggingface/trl/pull/1468 for why datasets.IterableDataset needs a separate check
|
||||
if isinstance(
|
||||
dataset, (torch.utils.data.IterableDataset, torch.utils.data.Dataset, ConstantLengthDataset)
|
||||
) and not isinstance(dataset, datasets.IterableDataset):
|
||||
return dataset
|
||||
|
||||
if not packing:
|
||||
@ -484,6 +529,9 @@ class SFTTrainer(Trainer):
|
||||
add_special_tokens=add_special_tokens,
|
||||
)
|
||||
|
||||
if isinstance(dataset, datasets.IterableDataset):
|
||||
return constant_length_iterator
|
||||
|
||||
def data_generator(constant_length_iterator):
|
||||
yield from constant_length_iterator
|
||||
|
||||
|
@ -20,9 +20,15 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from accelerate import PartialState
|
||||
from rich.console import Console, Group
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.data import IterableDataset
|
||||
from transformers import BitsAndBytesConfig, DataCollatorForLanguageModeling, PreTrainedTokenizerBase
|
||||
from transformers.trainer import TrainerCallback
|
||||
from transformers.trainer_utils import has_length
|
||||
|
||||
from ..import_utils import is_peft_available, is_unsloth_available, is_xpu_available
|
||||
from ..trainer.model_config import ModelConfig
|
||||
@ -319,7 +325,7 @@ class DPODataCollatorWithPadding:
|
||||
padding_value = self.pad_token_id
|
||||
elif k.endswith("_attention_mask"):
|
||||
padding_value = 0
|
||||
elif (k.startswith("chosen")) or (k.startswith("rejected")) or ("decoder" in k):
|
||||
elif k.startswith(("chosen", "rejected", "completion")) or ("decoder" in k):
|
||||
padding_value = self.label_pad_token_id
|
||||
else:
|
||||
raise ValueError(f"Unexpected key in batch '{k}'")
|
||||
@ -715,14 +721,97 @@ def get_peft_config(model_config: ModelConfig) -> "Optional[PeftConfig]":
|
||||
if model_config.use_peft is False:
|
||||
return None
|
||||
|
||||
if not is_peft_available():
|
||||
raise ValueError(
|
||||
"You need to have PEFT library installed in your environment, make sure to install `peft`. "
|
||||
"Make sure to run `pip install -U peft`."
|
||||
)
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=model_config.lora_r,
|
||||
lora_alpha=model_config.lora_alpha,
|
||||
lora_dropout=model_config.lora_dropout,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
task_type=model_config.lora_task_type,
|
||||
target_modules=model_config.lora_target_modules,
|
||||
modules_to_save=model_config.lora_modules_to_save,
|
||||
)
|
||||
|
||||
return peft_config
|
||||
|
||||
|
||||
class RichProgressCallback(TrainerCallback):
|
||||
"""
|
||||
A [`TrainerCallback`] that displays the progress of training or evaluation using Rich.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.training_bar = None
|
||||
self.prediction_bar = None
|
||||
|
||||
self.training_task_id = None
|
||||
self.prediction_task_id = None
|
||||
|
||||
self.rich_group = None
|
||||
self.rich_console = None
|
||||
|
||||
self.training_status = None
|
||||
self.current_step = None
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
self.training_bar = Progress()
|
||||
self.prediction_bar = Progress()
|
||||
|
||||
self.rich_console = Console()
|
||||
|
||||
self.training_status = self.rich_console.status("Nothing to log yet ...")
|
||||
|
||||
self.rich_group = Live(Panel(Group(self.training_bar, self.prediction_bar, self.training_status)))
|
||||
self.rich_group.start()
|
||||
|
||||
self.training_task_id = self.training_bar.add_task("[blue]Training the model", total=state.max_steps)
|
||||
self.current_step = 0
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
self.training_bar.update(self.training_task_id, advance=state.global_step - self.current_step, update=True)
|
||||
self.current_step = state.global_step
|
||||
|
||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||
if state.is_world_process_zero and has_length(eval_dataloader):
|
||||
if self.prediction_task_id is None:
|
||||
self.prediction_task_id = self.prediction_bar.add_task(
|
||||
"[blue]Predicting on the evaluation dataset", total=len(eval_dataloader)
|
||||
)
|
||||
self.prediction_bar.update(self.prediction_task_id, advance=1, update=True)
|
||||
|
||||
def on_evaluate(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
if self.prediction_task_id is not None:
|
||||
self.prediction_bar.remove_task(self.prediction_task_id)
|
||||
self.prediction_task_id = None
|
||||
|
||||
def on_predict(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
if self.prediction_task_id is not None:
|
||||
self.prediction_bar.remove_task(self.prediction_task_id)
|
||||
self.prediction_task_id = None
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if state.is_world_process_zero and self.training_bar is not None:
|
||||
_ = logs.pop("total_flos", None)
|
||||
self.training_status.update(f"[bold green]Status = {str(logs)}")
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
self.rich_group.stop()
|
||||
|
||||
self.training_bar = None
|
||||
self.prediction_bar = None
|
||||
self.training_task_id = None
|
||||
self.prediction_task_id = None
|
||||
self.rich_group = None
|
||||
self.rich_console = None
|
||||
self.training_status = None
|
||||
self.current_step = None
|
||||
|
Reference in New Issue
Block a user