Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00

🎁 RewardTrainer refactor (#4093)

Co-authored-by: juejuezi <juejuezi.git@foxmail.com>
Co-authored-by: Yi Shi <96773624+singing-cat@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

Committed by GitHub
parent ebb8899f5d
commit da209f89fc

README.md | 14
@@ -136,23 +136,13 @@ trainer.train()
 Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):

 ```python
-from trl import RewardConfig, RewardTrainer
+from trl import RewardTrainer
 from datasets import load_dataset
-from transformers import AutoModelForSequenceClassification, AutoTokenizer

-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
-model = AutoModelForSequenceClassification.from_pretrained(
-    "Qwen/Qwen2.5-0.5B-Instruct", num_labels=1
-)
-model.config.pad_token_id = tokenizer.pad_token_id
-
 dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

-training_args = RewardConfig(output_dir="Qwen2.5-0.5B-Reward", per_device_train_batch_size=2)
 trainer = RewardTrainer(
-    args=training_args,
+    model="Qwen/Qwen2.5-0.5B-Instruct",
-    model=model,
-    processing_class=tokenizer,
     train_dataset=dataset,
 )
 trainer.train()
@@ -9,6 +9,7 @@ Currently supported commands are:
 - `trl dpo`: fine-tune a LLM with DPO
 - `trl grpo`: fine-tune a LLM with GRPO
 - `trl kto`: fine-tune a LLM with KTO
+- `trl reward`: train a Reward Model
 - `trl rloo`: fine-tune a LLM with RLOO
 - `trl sft`: fine-tune a LLM with SFT

@@ -41,6 +42,15 @@ trl dpo \
     --dataset_name anthropic/hh-rlhf
 ```

+</hfoption>
+<hfoption id="Reward">
+
+```bash
+trl reward \
+    --model_name_or_path Qwen/Qwen2.5-0.5B \
+    --dataset_name trl-lib/ultrafeedback_binarized
+```
+
 </hfoption>
 </hfoptions>

@@ -78,6 +88,21 @@ Launch with:
 trl dpo --config dpo_config.yaml
 ```

+</hfoption>
+<hfoption id="Reward">
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+dataset_name: trl-lib/ultrafeedback_binarized
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 </hfoption>
 </hfoptions>

@@ -138,6 +163,33 @@ Launch with:
 ```bash
 trl dpo --config dpo_config.yaml
 ```

+</hfoption>
+<hfoption id="Reward inline">
+
+```bash
+trl reward \
+    --model_name_or_path Qwen/Qwen2.5-0.5B \
+    --dataset_name trl-lib/ultrafeedback_binarized \
+    --num_processes 4
+```
+
+</hfoption>
+<hfoption id="Reward w/ config file">
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+dataset_name: trl-lib/ultrafeedback_binarized
+num_processes: 4
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 </hfoption>
 </hfoptions>

@@ -217,6 +269,33 @@ Launch with:
 ```bash
 trl dpo --config dpo_config.yaml
 ```

+</hfoption>
+<hfoption id="Reward inline">
+
+```bash
+trl reward \
+    --model_name_or_path Qwen/Qwen2.5-0.5B \
+    --dataset_name trl-lib/ultrafeedback_binarized \
+    --accelerate_config zero2  # or path/to/my/accelerate/config.yaml
+```
+
+</hfoption>
+<hfoption id="Reward w/ config file">
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+dataset_name: trl-lib/ultrafeedback_binarized
+accelerate_config: zero2  # or path/to/my/accelerate/config.yaml
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 </hfoption>
 </hfoptions>

@@ -224,7 +303,7 @@ trl dpo --config dpo_config.yaml

 You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data.

-<hfoptions id="accelerate_config">
+<hfoptions id="dataset_mixtures">
 <hfoption id="SFT">

 ```yaml
@@ -258,6 +337,23 @@ Launch with:
 trl dpo --config dpo_config.yaml
 ```

+</hfoption>
+<hfoption id="Reward">
+
+```yaml
+# reward_config.yaml
+model_name_or_path: Qwen/Qwen2.5-0.5B
+datasets:
+  - path: trl-lib/tldr-preference
+  - path: trl-lib/lm-human-preferences-sentiment
+```
+
+Launch with:
+
+```bash
+trl reward --config reward_config.yaml
+```
+
 </hfoption>
 </hfoptions>

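The `datasets:` list above is resolved by the CLI into one training set. As a rough illustration of the idea only (not the CLI's actual mixing logic; column handling may differ), the same effect can be approximated with the `datasets` library directly:

```python
from datasets import concatenate_datasets, load_dataset

# Load each dataset listed in the mixture (paths taken from the YAML above)
parts = [
    load_dataset("trl-lib/tldr-preference", split="train"),
    load_dataset("trl-lib/lm-human-preferences-sentiment", split="train"),
]

# Keep only the shared preference columns so the parts can be concatenated
parts = [ds.select_columns(["chosen", "rejected"]) for ds in parts]

# Combine into a single training dataset and shuffle so the sources are interleaved
mixture = concatenate_datasets(parts).shuffle(seed=0)
print(mixture)
```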
@@ -533,3 +533,53 @@ training_args = CPOConfig(
     ...
 )
 ```
+
+## Reward Modeling
+
+Papers relating to the [`RewardTrainer`]
+
+### Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking
+
+**📜 Paper**: https://huggingface.co/papers/2312.09244
+
+This paper proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs and thereby resolving the issue of underdetermination.
+
+$$
+\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \textcolor{red}{- \eta \cdot (r_\theta(x, y^+) + r_\theta(x, y^-))^2} \right].
+$$
+
+To use this auxiliary loss with [`RewardTrainer`], you can use the `center_rewards_coefficient` argument in [`RewardConfig`] as follows:
+
+```python
+from trl import RewardConfig
+
+training_args = RewardConfig(
+    center_rewards_coefficient=0.01,  # η in the paper
+    ...
+)
+```
+
+### Llama 2: Open Foundation and Fine-Tuned Chat Models
+
+**📜 Paper**: https://huggingface.co/papers/2307.09288
+
+In this paper, the authors propose to leverage their preference ratings being decomposed as a scale of four points (e.g., _significantly better_) to provide more informative feedback to the reward model. This is done by adding a margin to the loss function, which encourages the reward model to assign larger gaps in scores for pairs with higher preference ratings.
+
+$$
+\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-,\textcolor{red}{m}) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-) \textcolor{red}{- m}) \right].
+$$
+
+You can add a margin to the loss by adding a `margin` column to the dataset. The following example shows how to set up the "Margin Small" setting of the paper.
+
+```python
+def add_margin(example):
+    preference_to_margin = {
+        "significantly better": 1.0,
+        "better": 2.0/3.0,
+        "slightly better": 1.0/3.0,
+        "negligibly better / unsure": 0.0,
+    }
+    return {"margin": preference_to_margin[example["preference_label"]]}

+dataset = dataset.map(add_margin)
+```
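To make the effect of the `margin` column concrete, here is a minimal sketch of the margin-aware Bradley-Terry loss from the formula above. This is a generic PyTorch illustration, not TRL's internal implementation:

```python
import torch
import torch.nn.functional as F

def margin_ranking_loss(chosen_rewards, rejected_rewards, margin=None):
    """Negative log-sigmoid of (r_chosen - r_rejected - margin), averaged over the batch."""
    diff = chosen_rewards - rejected_rewards
    if margin is not None:
        diff = diff - margin  # a larger margin demands a larger reward gap before the loss goes to zero
    return -F.logsigmoid(diff).mean()

chosen = torch.tensor([1.2, 0.3])
rejected = torch.tensor([0.4, 0.1])
print(margin_ranking_loss(chosen, rejected))                              # plain Bradley-Terry loss
print(margin_ranking_loss(chosen, rejected, torch.tensor([1.0, 1 / 3])))  # "Margin Small"-style margins
```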
@@ -1,6 +1,6 @@
 # Quickstart

 TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), Direct Preference Optimization (DPO).

 ## Quick Examples

@@ -51,6 +51,21 @@ trainer = DPOTrainer(
 trainer.train()
 ```

+### Reward Modeling
+
+```python
+from trl import RewardTrainer
+from datasets import load_dataset
+
+dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
+
+trainer = RewardTrainer(
+    model="Qwen/Qwen2.5-0.5B-Instruct",
+    train_dataset=dataset,
+)
+trainer.train()
+```
+
 ## Command Line Interface

 Skip the code entirely - train directly from your terminal:
@@ -63,6 +78,10 @@ trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
 # DPO: Align with preferences
 trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
     --dataset_name trl-lib/ultrafeedback_binarized
+
+# Reward: Train a reward model
+trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
+    --dataset_name trl-lib/ultrafeedback_binarized
 ```

 ## What's Next?
@@ -2,84 +2,225 @@

 [](https://huggingface.co/models?other=reward-trainer,trl)

-TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model.
+## Overview

-Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py).
+TRL supports the Outcome-supervised Reward Modeling (ORM) Trainer for training reward models.

-## Expected dataset type
+This post-training method was contributed by [Younes Belkada](https://huggingface.co/ybelkada).

-The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`).
-The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
+## Quick start

-You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`.
+This example demonstrates how to train a reward model using the [`RewardTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), a large-scale, fine-grained, diverse preference dataset.

-## Using the `RewardTrainer`
-
-After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers.
-You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training.
-
-### Leveraging 🤗 PEFT to train a reward model
-
-Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model!
-
 ```python
-from peft import LoraConfig, TaskType
+from trl import RewardTrainer
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from datasets import load_dataset
-from trl import RewardTrainer, RewardConfig

-model = AutoModelForSequenceClassification.from_pretrained("gpt2")
-peft_config = LoraConfig(
-    task_type=TaskType.SEQ_CLS,
-    inference_mode=False,
-    r=8,
-    lora_alpha=32,
-    lora_dropout=0.1,
-)
-
-...
-
 trainer = RewardTrainer(
-    model=model,
+    model="Qwen/Qwen3-0.6B",
-    args=training_args,
+    train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"),
-    processing_class=tokenizer,
+)
+trainer.train()
+```
+
+<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train*&sidebar=hidden&runs=reward_qwen3-0.6B_ultrafeedback2" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
+
+## Expected dataset type and format
+
+[`RewardTrainer`] supports [preference](dataset_formats#preference) dataset types (both implicit and explicit prompt). The [`RewardTrainer`] is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
+
+```python
+# Standard preference (implicit prompt)
+{"chosen": "The sky is blue.",
+ "rejected": "The sky is green."}
+
+# Conversational preference (implicit prompt)
+{"chosen": [{"role": "user", "content": "What color is the sky?"},
+            {"role": "assistant", "content": "It is blue."}],
+ "rejected": [{"role": "user", "content": "What color is the sky?"},
+              {"role": "assistant", "content": "It is green."}]}
+
+# Standard preference (explicit prompt)
+{"prompt": "The sky is",
+ "chosen": " blue.",
+ "rejected": " green."}
+
+# Conversational preference (explicit prompt)
+{"prompt": [{"role": "user", "content": "What color is the sky?"}],
+ "chosen": [{"role": "assistant", "content": "It is blue."}],
+ "rejected": [{"role": "assistant", "content": "It is green."}]}
+```
+
+If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the [lmarena-ai/arena-human-preference-55k](https://huggingface.co/datasets/lmarena-ai/arena-human-preference-55k) dataset:
+
+```python
+from datasets import load_dataset
+import json
+
+dataset = load_dataset("lmarena-ai/arena-human-preference-55k")
+
+# Filter out ties
+dataset = dataset.filter(lambda example: example["winner_tie"] == 0)
+
+# Create 'chosen' and 'rejected' fields based on the winner column
+def response_a_b_to_chosen_rejected(example):
+    if example["winner_model_a"] == 1:
+        example["chosen"] = example["response_a"]
+        example["rejected"] = example["response_b"]
+    else:
+        example["chosen"] = example["response_b"]
+        example["rejected"] = example["response_a"]
+    return example
+
+dataset = dataset.map(response_a_b_to_chosen_rejected)
+
+# Convert to conversational format
+def make_conversation(example):
+    prompt = json.loads(example["prompt"])[0]  # '["What color is the sky?"]' -> "What color is the sky?"
+    chosen = json.loads(example["chosen"])[0]
+    rejected = json.loads(example["rejected"])[0]
+    return {
+        "chosen": [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen}],
+        "rejected": [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected}],
+    }
+
+dataset = dataset.map(make_conversation)
+
+# Keep only necessary columns
+dataset = dataset.select_columns(["chosen", "rejected"])
+
+print(next(iter(dataset["train"])))
+```
+
+```json
+{
+    "chosen": [
+        {"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"},
+        {"role": "assistant", "content": "The question of whether it is morally right to aim for a certain percentage of females..."},
+    ],
+    "rejected": [
+        {"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"},
+        {"role": "assistant", "content": "As an AI, I don't have personal beliefs or opinions. However, ..."},
+    ],
+}
+```
+
+## Looking deeper into the training method
+
+Reward Models (RMs) are typically trained using supervised learning on datasets containing pairs of preferred and non-preferred responses. The goal is to learn a function that assigns higher scores to preferred responses, enabling the model to rank outputs based on preferences.
+
+This section breaks down how reward modeling works in practice, covering the key steps: **preprocessing** and **loss computation**.
+
+### Preprocessing and tokenization
+
+During training, each example is expected to contain a **chosen** and **rejected** field. For more details on the expected formats, see [Dataset formats - Preference](dataset_formats#preference).
+The [`RewardTrainer`] tokenizes each input using the model's tokenizer. If prompts and completions (chosen and rejected) are provided separately (explicit prompt case), they are concatenated before tokenization.
+
+### Computing the loss
+
+Let \\( x \\) be the input sequence (prompt) and \\( y^+ \\) and \\( y^- \\) be the chosen and rejected sequences respectively. Under the Bradley-Terry model ([Bradley & Terry, 1952](https://www.jstor.org/stable/2334029)), the probability that \\( y^+ \\) is preferred over \\( y^- \\) given a reward function \\( r \\) is \\( p(y^+ \succ y^- \mid x) = \sigma(r(x, y^+) - r(x, y^-)) \\), where \\( \sigma \\) is the sigmoid function.
+
+The reward model \\( r_\theta(x, y) \\) is trained to assign higher scores to preferred responses \\( y^+ \\) over non-preferred ones \\( y^- \\). The loss is then defined as the negative log-likelihood of the observed preferences:
+
+$$
+\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right].
+$$
+
+> [!TIP]
+> The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recommended value is `1e-2`.
+
+## Logged metrics
+
+While training and evaluating, we record the following reward metrics:
+
+* `global_step`: The total number of optimizer steps taken so far.
+* `epoch`: The current epoch number, based on dataset iteration.
+* `num_tokens`: The total number of tokens processed so far.
+* `loss`: The average loss over the last logging interval.
+* `accuracy`: The proportion of correct predictions (i.e., the model assigned a higher score to the chosen response than to the rejected one) averaged over the last logging interval.
+* `min_reward`: The minimum reward score assigned by the model. This value is averaged over the logging interval.
+* `mean_reward`: The average reward score assigned by the model over the last logging interval.
+* `max_reward`: The maximum reward score assigned by the model. This value is averaged over the logging interval.
+* `margin`: The average margin (difference between chosen and rejected rewards) over the last logging interval.
+* `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used.
+* `grad_norm`: The L2 norm of the gradients, computed before gradient clipping.
+
+## Customization
+
+### Model initialization
+
+You can directly pass the kwargs of the [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] method to the [`RewardConfig`]. For example, if you want to load a model in a different precision, analogous to
+
+```python
+model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16)
+```
+
+you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`RewardConfig`].
+
+```python
+from trl import RewardConfig
+
+training_args = RewardConfig(
+    model_init_kwargs={"dtype": torch.bfloat16},
+)
+```
+
+Note that all keyword arguments of [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] are supported, except for `num_labels`, which is automatically set to 1.
+
+### Train adapters with PEFT
+
+We support tight integration with the 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model.
+
+```python
+from datasets import load_dataset
+from trl import RewardTrainer
+from peft import LoraConfig
+
+dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
+
+trainer = RewardTrainer(
+    "Qwen/Qwen3-4B",
     train_dataset=dataset,
-    peft_config=peft_config,
+    peft_config=LoraConfig(modules_to_save=["score"])  # important to include the score head when base model is not a sequence classification model
 )

 trainer.train()

 ```

-### Adding a margin to the loss
+You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`RewardTrainer`] and pass it directly to the trainer without passing the `peft_config` argument.

-As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly.

 ```python
-def add_margin(row):
+from datasets import load_dataset
-    # Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin
+from trl import RewardTrainer
-    return {'margin': row['score_chosen'] - row['score_rejected']}
+from peft import AutoPeftModelForCausalLM

-dataset = dataset.map(add_margin)
+model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-Reward-LoRA", is_trainable=True)
-```
+dataset = load_dataset("trl-lib/Capybara", split="train")

-### Centering rewards
+trainer = RewardTrainer(
+    model=model,
-In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it.
+    train_dataset=dataset,

-[[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs:
-
-$$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$
-
-This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the `[RewardConfig]`. By default, this feature is deactivated (`center_rewards_coefficient = None`).
-
-```python
-training_args = RewardConfig(
-    center_rewards_coefficient=0.01,
-    ...
 )

+trainer.train()
 ```

-For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932).
+> [!TIP]
+> When training adapters, you typically use a higher learning rate (≈1e-3) since only new parameters are being learned.
+>
+> ```python
+> RewardConfig(learning_rate=1e-3, ...)
+> ```
+
+## Tool Calling with Reward Modeling
+
+The [`RewardTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:
+
+* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
+* The list of available tools in the `tools` column, typically provided as JSON schemas
+
+For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section.
+
 ## RewardTrainer

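As a rough, self-contained sketch of the training step described in the doc above (tokenize the chosen and rejected texts, score both with a sequence-classification head, then apply the Bradley-Terry loss with the optional centering term), assuming a generic model and tokenizer rather than TRL's internal implementation:

```python
import torch.nn.functional as F
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen3-0.6B", num_labels=1)
model.config.pad_token_id = tokenizer.pad_token_id

example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}

# Tokenize chosen and rejected together so they share one padded batch
batch = tokenizer([example["chosen"], example["rejected"]], return_tensors="pt", padding=True)
rewards = model(**batch).logits.squeeze(-1)   # shape (2,): [r_chosen, r_rejected]
r_chosen, r_rejected = rewards[0], rewards[1]

loss = -F.logsigmoid(r_chosen - r_rejected)   # Bradley-Terry negative log-likelihood
center_rewards_coefficient = 0.01             # optional auxiliary term (see the TIP above)
loss = loss + center_rewards_coefficient * (r_chosen + r_rejected) ** 2
print(loss.item())
```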
@@ -91,3 +232,7 @@ For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932).
 ## RewardConfig

 [[autodoc]] RewardConfig
+
+## DataCollatorForPreference
+
+[[autodoc]] trainer.reward_trainer.DataCollatorForPreference
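The tests later in this diff exercise this collator directly; as a quick illustration of the interface they rely on (pairs of `chosen_input_ids`/`rejected_input_ids` padded and stacked into a single batch):

```python
from trl.trainer.reward_trainer import DataCollatorForPreference

collator = DataCollatorForPreference(pad_token_id=0)
examples = [
    {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
    {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
]
batch = collator(examples)
# All chosen sequences come first, then all rejected ones, padded to the same length
print(batch["input_ids"])       # tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]])
print(batch["attention_mask"])
```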
@@ -23,7 +23,7 @@ trainer = SFTTrainer(
 trainer.train()
 ```

-<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train/loss,train/mean_token_accuracy,train/num_tokens&sidebar=hidden" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
+<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train*&runs=sft_qwen3-0.6B_capybara" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>

 ## Expected dataset type and format

@@ -64,4 +64,4 @@ trainer.train()

 will give you a hosted dashboard at https://huggingface.co/spaces/trl-lib/trackio.

-<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&sidebar=hidden" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
+<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&sidebar=hidden&runs=sft_qwen3-0.6B_capybara" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
@@ -43,6 +43,7 @@ from transformers import (
     GPT2LMHeadModel,
     GPTNeoXConfig,
     GPTNeoXForCausalLM,
+    GPTNeoXForSequenceClassification,
     GptOssConfig,
     GptOssForCausalLM,
     Idefics2Config,
@@ -73,6 +74,7 @@ from transformers import (
     Qwen3ForSequenceClassification,
     Qwen3MoeConfig,
     Qwen3MoeForCausalLM,
+    Qwen3MoeForSequenceClassification,
     SmolVLMForConditionalGeneration,
     T5ForConditionalGeneration,
 )
@@ -234,22 +236,46 @@ model = Qwen3ForCausalLM(config)
 push_to_hub(model, tokenizer, "small")

 # Reward models
-for model_id, config_class, model_class, suffix in [
+for model_id, model_class, suffix in [
-    ("meta-llama/Llama-3.2-1B-Instruct", LlamaConfig, LlamaForSequenceClassification, "3.2"),
+    ("EleutherAI/pythia-14m", GPTNeoXForSequenceClassification, None),
-    ("Qwen/Qwen2.5-32B-Instruct", Qwen2Config, Qwen2ForSequenceClassification, "2.5"),
+    ("meta-llama/Llama-3.2-1B-Instruct", LlamaForSequenceClassification, "3.2"),
-    ("Qwen/Qwen3-4B", Qwen3Config, Qwen3ForSequenceClassification, None),
+    ("Qwen/Qwen2.5-32B-Instruct", Qwen2ForSequenceClassification, "2.5"),
+    ("Qwen/Qwen3-4B", Qwen3ForSequenceClassification, None),
 ]:
     tokenizer = AutoTokenizer.from_pretrained(model_id)
-    config = config_class(
+    kwargs = {
-        vocab_size=len(tokenizer.vocab),
+        "num_labels": 1,
-        hidden_size=8,
+        "hidden_size": 16,
-        num_attention_heads=4,
+        "num_attention_heads": 4,
-        num_key_value_heads=2,
+        "num_key_value_heads": 2,
-        num_hidden_layers=2,
+        "num_hidden_layers": 2,
-        intermediate_size=32,
+        "intermediate_size": 32,
-        num_labels=1,
+    }
-    )
+    config = AutoConfig.from_pretrained(model_id, **kwargs)
-    model = model_class(config)
+    # Bug in transformers: it ignores num_hidden_layers to build layer_types
+    if model_id in ("Qwen/Qwen2.5-32B-Instruct", "Qwen/Qwen3-4B"):
+        config.layer_types = config.layer_types[:2]
+    model = model_class(config).to(dtype=torch.bfloat16)
+    init_weights_tiny_model(model)
+    push_to_hub(model, tokenizer, "tiny", suffix)
+
+# MoE Reward models
+for model_id, model_class, suffix in [
+    ("Qwen/Qwen3-30B-A3B", Qwen3MoeForSequenceClassification, None),
+]:
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    kwargs = {
+        "num_labels": 1,
+        "hidden_size": 16,
+        "num_attention_heads": 4,
+        "num_key_value_heads": 2,
+        "num_hidden_layers": 2,
+        "intermediate_size": 32,
+        "num_experts": 4,
+        "num_experts_per_tok": 2,
+    }
+    config = AutoConfig.from_pretrained(model_id, **kwargs)
+    model = model_class(config).to(dtype=torch.bfloat16)
     push_to_hub(model, tokenizer, "tiny", suffix)


@@ -315,7 +341,5 @@ for model_id, model_class in [
         kwargs["perceiver_config"] = {"hidden_size": 16}
-
     config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs)
-
     model = model_class(config).to(dtype=torch.bfloat16)

     push_to_hub(model, processor, "tiny")
@@ -67,6 +67,13 @@ class TestCLI(TrlTestCase):
         with patch("sys.argv", command.split(" ")):
             main()

+    def test_reward(self):
+        from trl.cli import main
+
+        command = f"trl reward --output_dir {self.tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_implicit_prompt_preference --report_to none"
+        with patch("sys.argv", command.split(" ")):
+            main()
+
     def test_rloo(self):
         from trl.cli import main

@@ -12,217 +12,823 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import pathlib
+import unittest
+
 import torch
-from datasets import Dataset, load_dataset
+from datasets import load_dataset
+from parameterized import parameterized
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 from transformers.testing_utils import require_peft
 from transformers.utils import is_peft_available

-from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template
+from trl import RewardConfig, RewardTrainer
-from trl.trainer.reward_trainer import _tokenize
+from trl.trainer.reward_trainer import DataCollatorForPreference

 from .testing_utils import TrlTestCase


 if is_peft_available():
-    from peft import LoraConfig, TaskType
+    from peft import LoraConfig, PeftModel, get_peft_model
+
+
+class TestDataCollatorForPreference(TrlTestCase):
+    def test_basic_padding(self):
+        """Test basic padding functionality without completion masks."""
+        self.collator = DataCollatorForPreference(pad_token_id=0)
+        examples = [
+            {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
+            {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
+        ]
+
+        result = self.collator(examples)
+
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]]))
+        torch.testing.assert_close(
+            result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]])
+        )
+
+    def test_pad_to_multiple_of(self):
+        """Test padding to multiple of specified value."""
+        collator = DataCollatorForPreference(pad_token_id=0, pad_to_multiple_of=4)
+        examples = [
+            {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
+            {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
+        ]
+
+        result = collator(examples)
+
+        torch.testing.assert_close(
+            result["input_ids"], torch.tensor([[1, 2, 3, 0], [6, 7, 0, 0], [4, 5, 0, 0], [8, 0, 0, 0]])
+        )
+        torch.testing.assert_close(
+            result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
+        )
+
+    def test_single_example(self):
+        """Test collator with a single example."""
+        self.collator = DataCollatorForPreference(pad_token_id=0)
+        examples = [{"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}]
+
+        result = self.collator(examples)
+
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
+        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
+
+    def test_different_pad_token_id(self):
+        """Test with different pad token ID."""
+        collator = DataCollatorForPreference(pad_token_id=999)
+        examples = [
+            {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
+            {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
+        ]
+
+        result = collator(examples)
+
+        torch.testing.assert_close(
+            result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 999], [4, 5, 999], [8, 999, 999]])
+        )
+        torch.testing.assert_close(
+            result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]])
+        )
+
+    def test_collate_with_margin(self):
+        self.collator = DataCollatorForPreference(pad_token_id=0)
+        examples = [
+            {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.1},
+            {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.2},
+        ]
+
+        result = self.collator(examples)
+
+        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]]))
+        torch.testing.assert_close(
+            result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]])
+        )
+        torch.testing.assert_close(result["margin"], torch.tensor([0.1, 0.2]))
+
+
 class RewardTrainerTester(TrlTestCase):
-    def setUp(self):
+    @parameterized.expand(
-        super().setUp()
+        [
-        self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
+            ("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",),
-        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
+            ("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification",),
-        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_id)
+            ("trl-internal-testing/tiny-LlamaForSequenceClassification-3.2",),
-        self.model.config.pad_token_id = self.tokenizer.pad_token_id
+        ]
+    )
+    def test_train(self, model_id):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
-    def test_preprocessing_conversational(self):
+        # Initialize the trainer
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
         training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
-        trainer = RewardTrainer(
+        trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=dataset)
-            model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
-        )
-        dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
-        dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
-        self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])

-    def test_preprocessing_standard(self):
+        # Save the initial parameters to compare them later
-        # No chat template, so we load a fresh tokenizer
-        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
-        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
-        trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=tokenizer, train_dataset=dummy_dataset
-        )
-        dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": tokenizer})
-        self.assertDictEqual(trainer.train_dataset[:], dummy_dataset[:])
-
-    def test_train_full(self):
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
-        training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
-        trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
-        )
         previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
         trainer.train()
+
+        # Check that the training loss is not None
         self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
-        # Check that the parameters have changed
+        # Check the params have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            if param.sum() != 0:  # ignore 0 biases
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
-                self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))

-    def test_train_full_pretokenized(self):
+    @parameterized.expand(
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
+        [
-        dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
+            ("standard_preference",),
-        dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
+            ("conversational_preference",),
-        training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
+            ("standard_implicit_prompt_preference",),
+            ("conversational_implicit_prompt_preference",),
+        ]
+    )
+    def test_train_dataset_types(self, config_name):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
         trainer = RewardTrainer(
-            model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
         )
+
+        # Save the initial parameters to compare them later
         previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
         trainer.train()
+
+        # Check that the training loss is not None
         self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
-        # Check that the parameters have changed
+        # Check the params have changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            if param.sum() != 0:  # ignore 0 biases
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
-                self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12))
+
+    def test_train_model(self):
+        # Instantiate the model
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
+        )
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_from_causal_lm(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    def test_train_model_dtype(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(
+            output_dir=self.tmp_dir,
+            model_init_kwargs={"dtype": torch.float16},
+            learning_rate=0.1,
+            report_to="none",
+        )
+        trainer = RewardTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            # For some reason model.layers.0.input_layernorm.weight doesn't change in GitHub Actions but does
+            # locally. We ignore this parameter for now
+            if "layernorm" in n:
+                continue
+            new_param = trainer.model.get_parameter(n)
+            # Check the torch dtype
+            self.assertEqual(new_param.dtype, torch.float16)
+            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

     @require_peft
-    def test_train_lora(self):
+    def test_train_dense_with_peft_config(self):
-        peft_config = LoraConfig(
+        # Get the base model parameter names
-            task_type=TaskType.SEQ_CLS,
+        model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
-            inference_mode=False,
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
-            r=8,
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
-            lora_alpha=32,
-            lora_dropout=0.1,
+        # Get the dataset
-        )
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")

-        training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+
         trainer = RewardTrainer(
-            model=self.model,
+            model=model_id,
             args=training_args,
-            processing_class=self.tokenizer,
+            train_dataset=dataset,
-            train_dataset=dummy_dataset,
+            peft_config=LoraConfig(),
-            peft_config=peft_config,
         )
-        previous_trainable_params = {}
-        previous_non_trainable_params = {}

-        # due to a change in the way the modules to save are dealt in PEFT.
+        # Save the initial parameters to compare them later
-        trainable_params_name = ["lora", "modules_to_save"]
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

-        # check gradients are not None
-        for n, param in trainer.model.named_parameters():
-            if any(t in n for t in trainable_params_name):
-                previous_trainable_params[n] = param.clone()
-            else:
-                previous_non_trainable_params[n] = param.clone()

+        # Train the model
         trainer.train()

-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

-        # Check that the parameters have changed
+        # Check the peft params have changed and the base model params have not changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
-        # Check that the non trainable parameters have not changed
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
-        for n, param in previous_non_trainable_params.items():
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
-            new_param = trainer.model.get_parameter(n)
-            self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))

     @require_peft
-    def test_train_lora_pretokenized(self):
+    def test_train_moe_with_peft_config(self):
-        peft_config = LoraConfig(
+        # Get the base model parameter names
-            task_type=TaskType.SEQ_CLS,
+        model_id = "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification"
-            inference_mode=False,
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
-            r=8,
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
-            lora_alpha=32,
-            lora_dropout=0.1,
+        # Get the dataset
-        )
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
-        dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
-        dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer})
+        # Initialize the trainer
-        dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer})
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
-        training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
         trainer = RewardTrainer(
-            model=self.model,
+            model=model_id,
             args=training_args,
-            processing_class=self.tokenizer,
+            train_dataset=dataset,
-            train_dataset=dummy_dataset,
+            peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]),
-            peft_config=peft_config,
         )
-        previous_trainable_params = {}
-        previous_non_trainable_params = {}

-        # due to a change in the way the modules to save are dealt in PEFT.
+        # Save the initial parameters to compare them later
-        trainable_params_name = ["lora", "modules_to_save"]
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

-        # check gradients are not None
-        for n, param in trainer.model.named_parameters():
-            if any(t in n for t in trainable_params_name):
-                previous_trainable_params[n] = param.clone()
-            else:
-                previous_non_trainable_params[n] = param.clone()

+        # Train the model
         trainer.train()

-        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

-        # Check that the parameters have changed
+        # Check the peft params have changed and the base model params have not changed
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

-        # Check that the non trainable parameters have not changed
+    @require_peft
-        for n, param in previous_non_trainable_params.items():
+    def test_train_peft_model(self):
+        # Get the base model
+        model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
+
+        # Get the base model parameter names
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Turn the model into a peft model
+        lora_config = LoraConfig()
+        model = get_peft_model(model, lora_config)
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
+        trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
-            self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
+            if n in base_param_names:  # We expect the base model parameters to be the same
+                self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
+            elif "base_layer" not in n:  # We expect the peft parameters to be different (except for the base layer)
+                self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
+
+    @require_peft
+    def test_train_dense_with_peft_config_and_gradient_checkpointing(self):
+        # Get the base model parameter names
+        model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
+        model = AutoModelForSequenceClassification.from_pretrained(model_id)
+        base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
+
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
+
+        # Initialize the trainer
+        training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
+
-    def test_margin(self):
-        dummy_dataset_dict = {
-            "input_ids_chosen": [
-                torch.LongTensor([0, 1, 2]),
-            ],
-            "attention_mask_chosen": [
-                torch.LongTensor([1, 1, 1]),
-            ],
-            "input_ids_rejected": [
-                torch.LongTensor([0, 2]),
-            ],
-            "attention_mask_rejected": [
-                torch.LongTensor([1, 1]),
-            ],
-            "margin": [
-                torch.FloatTensor([1.0]),
-            ],
-        }
-        dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
|
|
||||||
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
|
|
||||||
trainer = RewardTrainer(
|
trainer = RewardTrainer(
|
||||||
model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
|
model=model_id,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
peft_config=LoraConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
batch = [dummy_dataset[0]]
|
# Save the initial parameters to compare them later
|
||||||
batch = trainer.data_collator(batch)
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
batch = {k: v.to(trainer.model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
|
|
||||||
loss, outputs = trainer.compute_loss(trainer.model, batch, return_outputs=True)
|
|
||||||
|
|
||||||
l_val = -torch.nn.functional.logsigmoid(
|
# Train the model
|
||||||
outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"]
|
trainer.train()
|
||||||
).mean()
|
|
||||||
|
|
||||||
self.assertLess(abs(loss - l_val), 1e-6)
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
def test_tags(self):
|
# Check the peft params have changed and the base model params have not changed
|
||||||
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
if n in base_param_names: # We expect the base model parameters to be the same
|
||||||
|
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
|
||||||
|
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
@require_peft
|
||||||
|
def test_train_moe_with_peft_config_and_gradient_checkpointing(self):
|
||||||
|
# Get the base model parameter names
|
||||||
|
model_id = "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification"
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(model_id)
|
||||||
|
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
|
||||||
|
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
|
||||||
|
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model=model_id,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the peft params have changed and the base model params have not changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
if n in base_param_names: # We expect the base model parameters to be the same
|
||||||
|
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
|
||||||
|
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
@require_peft
|
||||||
|
def test_train_with_peft_model_and_gradient_checkpointing(self):
|
||||||
|
# Get the base model parameter names
|
||||||
|
model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(model_id)
|
||||||
|
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
|
||||||
|
model = get_peft_model(model, LoraConfig())
|
||||||
|
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
|
||||||
|
|
||||||
|
trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
|
||||||
|
|
||||||
|
# Verify model is a PeftModel
|
||||||
|
self.assertIsInstance(trainer.model, PeftModel)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the peft params have changed and the base model params have not changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
if n in base_param_names: # We expect the base model parameters to be the same
|
||||||
|
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
|
||||||
|
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
def test_train_with_pretokenized_data(self):
|
||||||
|
# Get the dataset
|
||||||
|
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
|
||||||
|
|
||||||
|
def tokenize_example(example):
|
||||||
|
return {
|
||||||
|
"chosen_input_ids": tokenizer(example["chosen"]).input_ids,
|
||||||
|
"rejected_input_ids": tokenizer(example["rejected"]).input_ids,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Apply tokenization
|
||||||
|
tokenized_dataset = dataset.map(tokenize_example, remove_columns=["chosen", "rejected"])
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
|
||||||
|
trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
def test_train_with_iterable_dataset(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset(
|
||||||
|
"trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train", streaming=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
def test_train_with_chat_template_kwargs(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5")
|
||||||
|
# The following template is a simplified version of the Qwen chat template, where an additional argument
|
||||||
|
# `role_capital` is used to control the capitalization of roles.
|
||||||
|
tokenizer.chat_template = '{%- if messages[0]["role"] == "system" -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\n" + messages[0]["content"] + "<|im_end|>\\n" }}{%- else -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n" }}{%- endif -%}{%- for message in messages -%} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%} {{ "<|im_start|>" + (message.role.upper() if role_capital else message.role) + "\\n" + message.content + "<|im_end|>\\n" }} {%- elif message.role == "assistant" -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") }} {%- if message.content -%} {{ "\\n" + message.content }} {%- endif -%} {{ "<|im_end|>\\n" }} {%- elif message.role == "tool" -%} {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%} {{ "<|im_start|>" + ("USER" if role_capital else "user") }} {%- endif -%} {{ "\\n<tool_response>\\n" + message.content + "\\n</tool_response>" }} {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%} {{ "<|im_end|>\\n" }} {%- endif -%} {%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") + "\\n" }}{%- endif -%}'
|
||||||
|
|
||||||
|
dataset.add_column("chat_template_kwargs", [{"role_capital": bool(i % 2)} for i in range(len(dataset))])
|
||||||
|
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
def test_train_with_set_chat_template_from_model(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none")
|
||||||
|
# trl-internal-testing/tiny-GPTNeoXForSequenceClassification doesn't have a chat template set by default
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-GPTNeoXForSequenceClassification",
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
# RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models
|
||||||
|
# include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip
|
||||||
|
# this parameter.
|
||||||
|
if n == "gpt_neox.final_layer_norm.bias":
|
||||||
|
continue
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
def test_train_with_set_chat_template_from_path(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(
|
||||||
|
output_dir=self.tmp_dir,
|
||||||
|
chat_template_path=str(pathlib.Path(__file__).parent / "data" / "template.jinja"),
|
||||||
|
report_to="none",
|
||||||
|
)
|
||||||
|
# trl-internal-testing/tiny-GPTNeoXForSequenceClassification doesn't have a chat template set by default
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-GPTNeoXForSequenceClassification",
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
# RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models
|
||||||
|
# include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip
|
||||||
|
# this parameter.
|
||||||
|
if n == "gpt_neox.final_layer_norm.bias":
|
||||||
|
continue
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
# Check that the template saved in the output directory is the same as the one used for training
|
||||||
|
template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja"
|
||||||
|
self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}")
|
||||||
|
|
||||||
|
with open(template_path) as f:
|
||||||
|
template_content = f.read()
|
||||||
|
with open(training_args.chat_template_path) as f:
|
||||||
|
original_template_content = f.read()
|
||||||
|
self.assertEqual(
|
||||||
|
template_content, original_template_content, "Chat template content does not match the original"
|
||||||
|
)
|
||||||
|
|
||||||
|
@unittest.skip("Skipping until we have a dataset with tool calls")
|
||||||
|
def test_train_toolcall_data(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/toolcall", split="train")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
|
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
|
||||||
trainer = RewardTrainer(
|
trainer = RewardTrainer(
|
||||||
model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
|
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
)
|
)
|
||||||
self.assertEqual(trainer.model.model_tags, trainer._tag_names)
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
def test_train_with_eval(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none")
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset["train"],
|
||||||
|
eval_dataset=dataset["test"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the eval loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
|
||||||
|
|
||||||
|
def test_train_with_multiple_eval_dataset(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none")
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset["train"],
|
||||||
|
eval_dataset={"data1": dataset["test"], "data2": dataset["test"]},
|
||||||
|
)
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the eval losses are not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"])
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"])
|
||||||
|
|
||||||
|
def test_train_with_gradient_checkpointing(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
def test_tag_added(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||||
|
train_dataset=dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
for tag in ["reward-trainer", "trl"]:
|
||||||
|
self.assertIn(tag, trainer.model.model_tags)
|
||||||
|
|
||||||
|
@require_peft
|
||||||
|
def test_tag_added_peft(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||||
|
train_dataset=dataset,
|
||||||
|
peft_config=LoraConfig(),
|
||||||
|
)
|
||||||
|
|
||||||
|
for tag in ["reward-trainer", "trl"]:
|
||||||
|
self.assertIn(tag, trainer.model.model_tags)
|
||||||
|
|
||||||
|
def test_train_with_margin(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
|
||||||
|
|
||||||
|
def add_margin(example):
|
||||||
|
# dummy margin based on the length of the chosen summary
|
||||||
|
return {"margin": len(example["chosen"])}
|
||||||
|
|
||||||
|
dataset = dataset.map(add_margin)
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
|
||||||
|
def test_train_with_center_rewards_coefficient(self):
|
||||||
|
# Get the dataset
|
||||||
|
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
|
||||||
|
|
||||||
|
# Initialize the trainer
|
||||||
|
training_args = RewardConfig(output_dir=self.tmp_dir, center_rewards_coefficient=0.01, report_to="none")
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Save the initial parameters to compare them later
|
||||||
|
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check that the training loss is not None
|
||||||
|
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||||
|
|
||||||
|
# Check the params have changed
|
||||||
|
for n, param in previous_trainable_params.items():
|
||||||
|
new_param = trainer.model.get_parameter(n)
|
||||||
|
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
|
||||||
|
11
trl/cli.py
11
trl/cli.py
@ -24,6 +24,7 @@ from .scripts.dpo import make_parser as make_dpo_parser
|
|||||||
from .scripts.env import print_env
|
from .scripts.env import print_env
|
||||||
from .scripts.grpo import make_parser as make_grpo_parser
|
from .scripts.grpo import make_parser as make_grpo_parser
|
||||||
from .scripts.kto import make_parser as make_kto_parser
|
from .scripts.kto import make_parser as make_kto_parser
|
||||||
|
from .scripts.reward import make_parser as make_reward_parser
|
||||||
from .scripts.rloo import make_parser as make_rloo_parser
|
from .scripts.rloo import make_parser as make_rloo_parser
|
||||||
from .scripts.sft import make_parser as make_sft_parser
|
from .scripts.sft import make_parser as make_sft_parser
|
||||||
from .scripts.utils import TrlParser
|
from .scripts.utils import TrlParser
|
||||||
@ -45,6 +46,7 @@ def main():
|
|||||||
subparsers.add_parser("env", help="Print the environment information")
|
subparsers.add_parser("env", help="Print the environment information")
|
||||||
make_grpo_parser(subparsers)
|
make_grpo_parser(subparsers)
|
||||||
make_kto_parser(subparsers)
|
make_kto_parser(subparsers)
|
||||||
|
make_reward_parser(subparsers)
|
||||||
make_rloo_parser(subparsers)
|
make_rloo_parser(subparsers)
|
||||||
make_sft_parser(subparsers)
|
make_sft_parser(subparsers)
|
||||||
make_vllm_serve_parser(subparsers)
|
make_vllm_serve_parser(subparsers)
|
||||||
@ -111,6 +113,15 @@ def main():
|
|||||||
args.training_script_args = sys.argv[2:] # remove "trl" and "kto"
|
args.training_script_args = sys.argv[2:] # remove "trl" and "kto"
|
||||||
launch_command(args) # launch training
|
launch_command(args) # launch training
|
||||||
|
|
||||||
|
elif args.command == "reward":
|
||||||
|
# Get the default args for the launch command
|
||||||
|
reward_training_script = resources.files("trl.scripts").joinpath("reward.py")
|
||||||
|
args = launch_command_parser().parse_args([str(reward_training_script)])
|
||||||
|
|
||||||
|
# Feed the args to the launch command
|
||||||
|
args.training_script_args = sys.argv[2:] # remove "trl" and "reward"
|
||||||
|
launch_command(args) # launch training
|
||||||
|
|
||||||
elif args.command == "rloo":
|
elif args.command == "rloo":
|
||||||
# Get the default args for the launch command
|
# Get the default args for the launch command
|
||||||
rloo_training_script = resources.files("trl.scripts").joinpath("rloo.py")
|
rloo_training_script = resources.files("trl.scripts").joinpath("rloo.py")
|
||||||
|
@ -94,11 +94,8 @@ def setup_chat_format(
|
|||||||
Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
|
Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
|
||||||
embedding layer of the model based on the new special tokens.
|
embedding layer of the model based on the new special tokens.
|
||||||
|
|
||||||
<Tip warning="true">
|
> [!WARNING]
|
||||||
|
> This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead.
|
||||||
This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
If the model already has a chat template, this will throw an error. If you want to overwrite it, please set
|
If the model already has a chat template, this will throw an error. If you want to overwrite it, please set
|
||||||
`tokenizer.chat_template` to `None`.
|
`tokenizer.chat_template` to `None`.
|
||||||
|
109
trl/scripts/reward.py
Normal file
109
trl/scripts/reward.py
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
# /// script
|
||||||
|
# dependencies = [
|
||||||
|
# "trl",
|
||||||
|
# "peft",
|
||||||
|
# "trackio",
|
||||||
|
# "kernels",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from accelerate import logging
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
from trl import (
|
||||||
|
DatasetMixtureConfig,
|
||||||
|
ModelConfig,
|
||||||
|
RewardConfig,
|
||||||
|
RewardTrainer,
|
||||||
|
ScriptArguments,
|
||||||
|
TrlParser,
|
||||||
|
get_dataset,
|
||||||
|
get_peft_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
# Enable logging in a Hugging Face Space
|
||||||
|
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
|
||||||
|
|
||||||
|
|
||||||
|
def main(script_args, training_args, model_args, dataset_args):
|
||||||
|
# Load the dataset
|
||||||
|
if dataset_args.datasets and script_args.dataset_name:
|
||||||
|
logger.warning(
|
||||||
|
"Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
|
||||||
|
"dataset and `dataset_name` will be ignored."
|
||||||
|
)
|
||||||
|
dataset = get_dataset(dataset_args)
|
||||||
|
elif dataset_args.datasets and not script_args.dataset_name:
|
||||||
|
dataset = get_dataset(dataset_args)
|
||||||
|
elif not dataset_args.datasets and script_args.dataset_name:
|
||||||
|
dataset = load_dataset(
|
||||||
|
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
|
||||||
|
|
||||||
|
# Initialize the RewardTrainer
|
||||||
|
trainer = RewardTrainer(
|
||||||
|
model=model_args.model_name_or_path,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=dataset[script_args.dataset_train_split],
|
||||||
|
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||||
|
peft_config=get_peft_config(model_args),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Train the model
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Log training complete
|
||||||
|
trainer.accelerator.print("✅ Training completed.")
|
||||||
|
|
||||||
|
# Save and push to Hub
|
||||||
|
trainer.save_model(training_args.output_dir)
|
||||||
|
trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")
|
||||||
|
|
||||||
|
if training_args.push_to_hub:
|
||||||
|
trainer.push_to_hub(dataset_name=script_args.dataset_name)
|
||||||
|
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
|
||||||
|
|
||||||
|
|
||||||
|
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
|
||||||
|
dataclass_types = (ScriptArguments, RewardConfig, ModelConfig, DatasetMixtureConfig)
|
||||||
|
if subparsers is not None:
|
||||||
|
parser = subparsers.add_parser(
|
||||||
|
"reward", help="Run the reward training script", dataclass_types=dataclass_types
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parser = TrlParser(dataclass_types)
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = make_parser()
|
||||||
|
# When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
|
||||||
|
# To ensure that their parsing does not interfere with the script arguments, parse the arguments with
|
||||||
|
# `return_remaining_strings=True`, then ignore the remaining strings.
|
||||||
|
script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
|
||||||
|
return_remaining_strings=True
|
||||||
|
)
|
||||||
|
main(script_args, training_args, model_args, dataset_args)
|
55
trl/templates/rm_model_card.md
Normal file
55
trl/templates/rm_model_card.md
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
---
|
||||||
|
{{ card_data }}
|
||||||
|
---
|
||||||
|
|
||||||
|
# Model Card for {{ model_name }}
|
||||||
|
|
||||||
|
This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}.
|
||||||
|
It has been trained using [TRL](https://github.com/huggingface/trl).
|
||||||
|
|
||||||
|
## Quick start
|
||||||
|
|
||||||
|
```python
|
||||||
|
from transformers import pipeline
|
||||||
|
|
||||||
|
text = "The capital of France is Paris."
|
||||||
|
rewarder = pipeline(model="{{ hub_model_id }}", device="cuda")
|
||||||
|
output = rewarder(text)[0]
|
||||||
|
print(output["score"])
|
||||||
|
```
|
||||||
|
|
||||||
|
## Training procedure
|
||||||
|
|
||||||
|
{% if wandb_url %}[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>]({{ wandb_url }}){% endif %}
|
||||||
|
{% if comet_url %}[<img src="https://raw.githubusercontent.com/comet-ml/comet-examples/master/logo/comet_badge.png" alt="Visualize in Comet" width="135" height="20"/>]({{ comet_url }}){% endif %}
|
||||||
|
|
||||||
|
This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.
|
||||||
|
|
||||||
|
### Framework versions
|
||||||
|
|
||||||
|
- TRL: {{ trl_version }}
|
||||||
|
- Transformers: {{ transformers_version }}
|
||||||
|
- Pytorch: {{ pytorch_version }}
|
||||||
|
- Datasets: {{ datasets_version }}
|
||||||
|
- Tokenizers: {{ tokenizers_version }}
|
||||||
|
|
||||||
|
## Citations
|
||||||
|
|
||||||
|
{% if trainer_citation %}Cite {{ trainer_name }} as:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
{{ trainer_citation }}
|
||||||
|
```{% endif %}
|
||||||
|
|
||||||
|
Cite TRL as:
|
||||||
|
|
||||||
|
```bibtex
|
||||||
|
{% raw %}@misc{vonwerra2022trl,
|
||||||
|
title = {{TRL: Transformer Reinforcement Learning}},
|
||||||
|
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
|
||||||
|
year = 2020,
|
||||||
|
journal = {GitHub repository},
|
||||||
|
publisher = {GitHub},
|
||||||
|
howpublished = {\url{https://github.com/huggingface/trl}}
|
||||||
|
}{% endraw %}
|
||||||
|
```
|
@ -28,6 +28,7 @@ class BaseTrainer(Trainer):
|
|||||||
_tag_names = []
|
_tag_names = []
|
||||||
_name = "Base"
|
_name = "Base"
|
||||||
_paper = {}
|
_paper = {}
|
||||||
|
_template_file = None
|
||||||
|
|
||||||
def create_model_card(
|
def create_model_card(
|
||||||
self,
|
self,
|
||||||
@ -78,6 +79,7 @@ class BaseTrainer(Trainer):
|
|||||||
comet_url=get_comet_experiment_url(),
|
comet_url=get_comet_experiment_url(),
|
||||||
trainer_name=self._name,
|
trainer_name=self._name,
|
||||||
trainer_citation=self._paper.get("citation"),
|
trainer_citation=self._paper.get("citation"),
|
||||||
|
template_file=self._template_file,
|
||||||
paper_title=self._paper.get("title"),
|
paper_title=self._paper.get("title"),
|
||||||
paper_id=self._paper.get("id"),
|
paper_id=self._paper.get("id"),
|
||||||
)
|
)
|
||||||
|
@ -13,7 +13,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
|
|
||||||
@ -32,22 +32,53 @@ class RewardConfig(TrainingArguments):
|
|||||||
command line.
|
command line.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
> Parameters that control the model
|
||||||
Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
|
|
||||||
limit. This argument is required if you want to use the default data collator.
|
model_init_kwargs (`dict[str, Any]`, *optional*):
|
||||||
|
Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
|
||||||
|
argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want
|
||||||
|
to include the load balancing/auxilliary loss as a part of the final loss, remember to set
|
||||||
|
`output_router_logits=True` in this dictionary.
|
||||||
|
chat_template_path (`str`, *optional*):
|
||||||
|
If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
|
||||||
|
or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
|
||||||
|
ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
|
||||||
|
embedding layer is resized accordingly.
|
||||||
disable_dropout (`bool`, *optional*, defaults to `True`):
|
disable_dropout (`bool`, *optional*, defaults to `True`):
|
||||||
Whether to disable dropout in the model.
|
Whether to disable dropout in the model.
|
||||||
|
|
||||||
|
> Parameters that control the data preprocessing
|
||||||
|
|
||||||
dataset_num_proc (`int`, *optional*):
|
dataset_num_proc (`int`, *optional*):
|
||||||
Number of processes to use for processing the dataset.
|
Number of processes to use for processing the dataset.
|
||||||
|
eos_token (`str`, *optional*):
|
||||||
|
Token used to indicate the end of a turn or sequence. If `None`, it defaults to
|
||||||
|
`processing_class.eos_token`.
|
||||||
|
pad_token (`str`, *optional*):
|
||||||
|
Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
|
||||||
|
it falls back to `processing_class.eos_token`.
|
||||||
|
max_length (`int` or `None`, *optional*, defaults to `1024`):
|
||||||
|
Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence
|
||||||
|
exceeds this value. If `None`, no filtering is applied.
|
||||||
|
pad_to_multiple_of (`int`, *optional*):
|
||||||
|
If set, the sequences will be padded to a multiple of this value.
|
||||||
|
|
||||||
|
> Parameters that control the training
|
||||||
|
|
||||||
center_rewards_coefficient (`float`, *optional*):
|
center_rewards_coefficient (`float`, *optional*):
|
||||||
Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
|
Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
|
||||||
https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
|
https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
|
||||||
remove_unused_columns (`bool`, *optional*, defaults to `False`):
|
activation_offloading (`bool`, *optional*, defaults to `False`):
|
||||||
Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if the
|
Whether to offload the activations to the CPU.
|
||||||
dataset is pretokenized.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
|
||||||
|
|
||||||
# Parameters whose default values are overridden from TrainingArguments
|
# Parameters whose default values are overridden from TrainingArguments
|
||||||
|
learning_rate: float = field(
|
||||||
|
default=1e-4,
|
||||||
|
metadata={"help": "The initial learning rate for AdamW."},
|
||||||
|
)
|
||||||
logging_steps: float = field(
|
logging_steps: float = field(
|
||||||
default=10,
|
default=10,
|
||||||
metadata={
|
metadata={
|
||||||
@ -70,21 +101,59 @@ class RewardConfig(TrainingArguments):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
max_length: Optional[int] = field(
|
# Parameters that control the model
|
||||||
default=1024,
|
model_init_kwargs: Optional[dict[str, Any]] = field(
|
||||||
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that "
|
"help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
|
||||||
"exceed the limit. This argument is required if you want to use the default data collator."
|
"the `RewardTrainer` is provided as a string."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chat_template_path: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local "
|
||||||
|
"directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, "
|
||||||
|
"you must ensure that any special tokens referenced in the template are added to the tokenizer and "
|
||||||
|
"that the model's embedding layer is resized accordingly."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
disable_dropout: bool = field(
|
disable_dropout: bool = field(
|
||||||
default=True,
|
default=True,
|
||||||
metadata={"help": "Whether to disable dropout in the model and reference model."},
|
metadata={"help": "Whether to disable dropout in the model."},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Parameters that control the data preprocessing
|
||||||
dataset_num_proc: Optional[int] = field(
|
dataset_num_proc: Optional[int] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Number of processes to use for processing the dataset."},
|
metadata={"help": "Number of processes to use for processing the dataset."},
|
||||||
)
|
)
|
||||||
|
eos_token: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
pad_token: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
|
||||||
|
"is also `None`, it falls back to `processing_class.eos_token`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_length: Optional[int] = field(
|
||||||
|
default=1024,
|
||||||
|
metadata={
|
||||||
|
"help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from"
|
||||||
|
"the right. If `None`, no truncation is applied."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
pad_to_multiple_of: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parameters that control the training
|
||||||
center_rewards_coefficient: Optional[float] = field(
|
center_rewards_coefficient: Optional[float] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
@ -92,15 +161,11 @@ class RewardConfig(TrainingArguments):
|
|||||||
"https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
|
"https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
remove_unused_columns: bool = field(
|
activation_offloading: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={
|
metadata={"help": "Whether to offload the activations to the CPU."},
|
||||||
"help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only "
|
|
||||||
"if the dataset is pretokenized."
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
|
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
|
||||||
|
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
@ -12,132 +12,348 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import FrozenInstanceError, replace
|
from contextlib import contextmanager
|
||||||
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from accelerate import PartialState, logging
|
import transformers
|
||||||
from accelerate.utils import gather_object
|
from accelerate import PartialState
|
||||||
from datasets import Dataset
|
from accelerate.logging import get_logger
|
||||||
|
from datasets import Dataset, IterableDataset
|
||||||
from transformers import (
|
from transformers import (
|
||||||
BaseImageProcessor,
|
AutoModelForSequenceClassification,
|
||||||
|
AutoTokenizer,
|
||||||
DataCollator,
|
DataCollator,
|
||||||
FeatureExtractionMixin,
|
|
||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
ProcessorMixin,
|
|
||||||
)
|
)
|
||||||
|
from transformers.data.data_collator import DataCollatorMixin
|
||||||
from transformers.trainer_callback import TrainerCallback
|
from transformers.trainer_callback import TrainerCallback
|
||||||
from transformers.trainer_pt_utils import nested_detach
|
|
||||||
from transformers.trainer_utils import EvalPrediction
|
from transformers.trainer_utils import EvalPrediction
|
||||||
from transformers.utils import is_peft_available, is_rich_available
|
from transformers.utils import is_peft_available
|
||||||
|
|
||||||
from ..data_utils import maybe_apply_chat_template
|
from ..data_utils import is_conversational
|
||||||
from ..models import prepare_peft_model
|
from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
|
||||||
from .base_trainer import BaseTrainer
|
from .base_trainer import BaseTrainer
|
||||||
from .reward_config import RewardConfig
|
from .reward_config import RewardConfig
|
||||||
from .utils import (
|
from .utils import disable_dropout_in_model, pad, remove_none_values
|
||||||
RewardDataCollatorWithPadding,
|
|
||||||
compute_accuracy,
|
|
||||||
decode_and_strip_padding,
|
|
||||||
disable_dropout_in_model,
|
|
||||||
log_table_to_comet_experiment,
|
|
||||||
print_rich_table,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if is_peft_available():
|
if is_peft_available():
|
||||||
from peft import PeftModel
|
from peft import PeftConfig, PeftModel
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase") -> dict[str, list[Any]]:
|
# AutoModelForSequenceClassification adds a new classification head when loading a CausalLM. That head is randomly
|
||||||
"""Tokenize a batch from a reward modelling dataset."""
|
# initialized and triggers a harmless warning about uninitialized weights. We suppress just that specific warning to
|
||||||
new_examples = {
|
# avoid confusing users.
|
||||||
"input_ids_chosen": [],
|
@contextmanager
|
||||||
"attention_mask_chosen": [],
|
def suppress_from_pretrained_warning(logger: logging.Logger):
|
||||||
"input_ids_rejected": [],
|
pattern = re.compile(
|
||||||
"attention_mask_rejected": [],
|
r"^Some weights of \S+ were not initialized from the model checkpoint at \S+ and are newly initialized: "
|
||||||
}
|
r"\[.*\]\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and "
|
||||||
for chosen, rejected in zip(batch["chosen"], batch["rejected"]):
|
r"inference\.$"
|
||||||
tokenized_chosen = tokenizer(chosen)
|
)
|
||||||
tokenized_rejected = tokenizer(rejected)
|
|
||||||
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
|
|
||||||
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
|
|
||||||
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
|
|
||||||
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
|
|
||||||
|
|
||||||
return new_examples
|
class _Filter(logging.Filter):
|
||||||
|
def filter(self, record: logging.LogRecord) -> bool:
|
||||||
|
return not pattern.search(record.getMessage())
|
||||||
|
|
||||||
|
f = _Filter()
|
||||||
|
logger.addFilter(f)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
logger.removeFilter(f)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DataCollatorForPreference(DataCollatorMixin):
|
||||||
|
"""
|
+    Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch.
+
+    This collator expects each example in the input list to be a dictionary containing the `"chosen_input_ids"` and
+    `"rejected_input_ids"` keys. The collator returns a dictionary containing the following keys:
+
+    - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch
+      corresponds to the `"chosen_input_ids"` and the second half to the `"rejected_input_ids"`.
+    - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
+
+    Optionally, the examples can contain a `"margin"` key, in which case the returned dictionary will also contain a
+    `"margin"` key with a tensor of margins.
+
+    Args:
+        pad_token_id (`int`):
+            Token ID to use for padding.
+        pad_to_multiple_of (`int`, *optional*):
+            If set, the sequences will be padded to a multiple of this value.
+        return_tensors (`str`, *optional*, defaults to `"pt"`):
+            Type of Tensor to return. Only `"pt"` is currently supported.
+
+    Examples:
+    ```python
+    >>> from trl.trainer.reward_trainer import DataCollatorForPreference
+
+    >>> collator = DataCollatorForPreference(pad_token_id=0)
+    >>> examples = [
+    ...     {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
+    ...     {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
+    ... ]
+    >>> collator(examples)
+    {'input_ids': tensor([[1, 2, 3],
+                          [6, 7, 0],
+                          [4, 5, 0],
+                          [8, 0, 0]]),
+     'attention_mask': tensor([[1, 1, 1],
+                               [1, 1, 0],
+                               [1, 1, 0],
+                               [1, 0, 0]])}
+
+    >>> examples = [
+    ...     {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.5},
+    ...     {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.0},
+    ... ]
+    >>> collator(examples)
+    {'input_ids': tensor([[1, 2, 3],
+                          [6, 7, 0],
+                          [4, 5, 0],
+                          [8, 0, 0]]),
+     'attention_mask': tensor([[1, 1, 1],
+                               [1, 1, 0],
+                               [1, 1, 0],
+                               [1, 0, 0]]),
+     'margin': tensor([0.5, 0.0])}
+    ```
+    """
+
+    pad_token_id: int
+    pad_to_multiple_of: Optional[int] = None
+    return_tensors: str = "pt"
+
+    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
+        # Convert to tensor
+        chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples]
+        rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples]
+        if "margin" in examples[0]:
+            margins = torch.tensor([example["margin"] for example in examples], dtype=torch.float)
+        input_ids = chosen_input_ids + rejected_input_ids
+        attention_mask = [torch.ones_like(ids) for ids in input_ids]
+
+        output = {}
+
+        # Pad
+        output["input_ids"] = pad(
+            input_ids,
+            padding_value=self.pad_token_id,
+            padding_side="right",
+            pad_to_multiple_of=self.pad_to_multiple_of,
+        )
+        output["attention_mask"] = pad(
+            attention_mask,
+            padding_value=0,
+            padding_side="right",
+            pad_to_multiple_of=self.pad_to_multiple_of,
+        )
+        if "margin" in examples[0]:
+            output["margin"] = margins
+        return output
+
+
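The chosen-first/rejected-second layout documented above is what lets downstream code recover per-pair scores with a single split. A minimal, illustrative sketch of that consumption pattern (the toy `scores` tensor is invented, not taken from the diff):

```python
import torch

# Reward-model outputs for a collated batch of 2 pairs: the first half scores the
# chosen sequences, the second half the rejected ones, in the same example order.
scores = torch.tensor([1.3, 0.9, 0.2, 1.1])

chosen_scores, rejected_scores = torch.chunk(scores, chunks=2)
accuracy = (chosen_scores > rejected_scores).float().mean()
print(accuracy)  # tensor(0.5000) -> only the first pair is ranked correctly
```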
 class RewardTrainer(BaseTrainer):
     """
-    Trainer for custom reward.
+    Trainer for Outcome-supervised Reward Models (ORM).
+
+    This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
+
+    Example:
+
+    ```python
+    from trl import RewardTrainer
+    from datasets import load_dataset
+
+    dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
+
+    trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset)
+    trainer.train()
+    ```

     Args:
-        model ([`~transformers.PreTrainedModel`] or `torch.nn.Module`, *optional*):
-            Model to be trained, preferably an [`~transformers.AutoModelForSequenceClassification`].
+        model (`Union[str, PreTrainedModel]`):
+            Model to be trained. Can be either:
+
+            - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
+              path to a *directory* containing model weights saved using
+              [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
+              using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in
+              `args.model_init_kwargs`.
+            - A sequence classification [`~transformers.PreTrainedModel`] object.
         args ([`RewardConfig`], *optional*):
-            Training arguments.
+            Configuration for this trainer. If `None`, a default configuration is used.
         data_collator ([`~transformers.DataCollator`], *optional*):
-            The data collator to use for training. If None is specified, the default data collator
-            [`~trainer.utils.RewardDataCollatorWithPadding`] will be used which will pad the sequences to the maximum
-            length of the sequences in the batch, given a dataset of paired sequences.
-        train_dataset ([`~datasets.Dataset`], *optional*):
-            The dataset to use for training.
-        eval_dataset ([`~datasets.Dataset`], *optional*):
-            The dataset to use for evaluation.
-        processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*):
-            Processing class used to process the data. If provided, will be used to automatically process the inputs
-            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
-            reuse the fine-tuned model.
-        model_init (`Callable[[], transformers.PreTrainedModel]`, *optional*):
-            The model initializer to use for training. If None is specified, the default model initializer will be
-            used.
-        compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to [`~trainer.utils.compute_accuracy`]):
-            Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a
-            dictionary string to float.
-        callbacks (`list` of [`~transformers.TrainerCallback`], *optional*):
-            Callbacks to use during training.
-        optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`):
-            Tuple containing the optimizer and the learning rate scheduler to use for training.
+            Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
+            Will default to [`~trainer.reward_trainer.DataCollatorForPreference`].
+        train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
+            Dataset to use for training. This trainer supports [preference](#preference) type (both implicit and
+            explicit prompt). The format of the samples can be either:
+
+            - [Standard](dataset_formats#standard): Each sample contains plain text.
+            - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
+              and content).
+
+            The trainer also supports processed datasets (tokenized) as long as they contain an `chosen_input_ids` and
+            `rejected_input_ids` fields.
+        eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
+            Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
+        processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*):
+            Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with
+            [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be
+            set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the
+            default.
+        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
+            The function that will be used to compute metrics at evaluation. Must take a
+            [`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing
+            [`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a
+            boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the
+            function needs to calculate and return the global summary statistics rather than accumulating the
+            batch-level statistics.
+        callbacks (list of [`~transformers.TrainerCallback`], *optional*):
+            List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
+            in [here](https://huggingface.co/docs/transformers/main_classes/callback).
+
+            If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
+            method.
+        optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
+            A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
+            model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
+        optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
+            A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
+            `args`. Incompatible with the `optimizers` argument.
+
+            Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
+            initializing the Trainer.
         preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
-            Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and
-            return the logits to be used for metrics computation.
-        peft_config (`dict`, *optional*):
-            PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be
-            wrapped with the specified PEFT adapter.
+            A function that preprocess the logits right before caching them at each evaluation step. Must take two
+            tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
+            by this function will be reflected in the predictions received by `compute_metrics`.
+
+            Note that the labels (second parameter) will be `None` if the dataset does not have them.
+        peft_config ([`~peft.PeftConfig`], *optional*):
+            PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded
+            model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration
+            to ensure that the reward head is properly trained.
     """

     _tag_names = ["trl", "reward-trainer"]
     _name = "Reward"
+    _template_file = "rm_model_card.md"

     def __init__(
         self,
-        model: Optional[Union[PreTrainedModel, nn.Module]] = None,
+        model: Union[str, PreTrainedModel],
         args: Optional[RewardConfig] = None,
         data_collator: Optional[DataCollator] = None,
-        train_dataset: Optional[Dataset] = None,
+        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
         eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
-        processing_class: Optional[
-            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
-        ] = None,
-        model_init: Optional[Callable[[], PreTrainedModel]] = None,
+        processing_class: Optional[PreTrainedTokenizerBase] = None,
         compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
         callbacks: Optional[list[TrainerCallback]] = None,
-        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (
-            None,
-            None,
-        ),
+        optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
+        optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
         preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-        peft_config: Optional[dict] = None,
+        peft_config: Optional["PeftConfig"] = None,
     ):
+        # Args
+        if args is None:
+            model_name = model if isinstance(model, str) else model.config._name_or_path
+            model_name = model_name.split("/")[-1]
+            args = RewardConfig(f"{model_name}-Reward")
+
+        # Model
+        model_init_kwargs = args.model_init_kwargs or {}
+        if isinstance(model, str):
+            model_id = model
+            dtype = model_init_kwargs.get("dtype")
+            if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
+                pass  # dtype is already a torch.dtype or "auto" or None
+            elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
+                model_init_kwargs["dtype"] = getattr(torch, dtype)
+            else:
+                raise ValueError(
+                    "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing "
+                    f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
+                )
+            with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
+                model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
+        else:
+            model_id = model.config._name_or_path
+            if args.model_init_kwargs is not None:
+                logger.warning(
+                    "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
+                    "The `model_init_kwargs` will be ignored."
+                )
+
+        # Processing class
+        if processing_class is None:
+            processing_class = AutoTokenizer.from_pretrained(model_id)
+
+        # Handle pad token for processors or tokenizers
+        if args.eos_token is not None:
+            eos_token = args.eos_token
+            eos_token_id = processing_class.convert_tokens_to_ids(eos_token)
+            if eos_token_id is None:
+                raise ValueError(
+                    f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
+                    f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
+                    "in the vocabulary before using it as an EOS token."
+                )
+            processing_class.eos_token_id = eos_token_id
+
+        if args.chat_template_path is not None:
+            if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
+                with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
+                    processing_class.chat_template = chat_template_file.read()
+                added_tokens = []
+            else:
+                model, processing_class, added_tokens = clone_chat_template(
+                    model, processing_class, args.chat_template_path
+                )
+        else:
+            added_tokens = []
+
+        # PEFT configuration and model wrapping
+        if peft_config is not None:
+            if added_tokens:
+                # Ensure that the added tokens are trainable
+                if peft_config.trainable_token_indices is None:
+                    peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
+                elif "embed_tokens" not in peft_config.trainable_token_indices:
+                    peft_config.trainable_token_indices["embed_tokens"] = added_tokens
+                else:
+                    peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)
+
+                # Ensure that the lm_head is trainable
+                if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
+                    logger.warning(
+                        "Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
+                        "`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
+                        "tokens, leading to degraded generation quality. To fix this, add "
+                        "`modules_to_save=['lm_head']` to your PEFT configuration."
+                    )
+
+                    if peft_config.modules_to_save is None:
+                        peft_config.modules_to_save = ["lm_head"]
+                    else:
+                        peft_config.modules_to_save.append("lm_head")
+
         if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
             model = prepare_peft_model(model, peft_config, args)
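The PEFT branch above only wraps the model; choosing which modules stay trainable is left to the user. A hedged sketch of a configuration that follows the `modules_to_save=["score"]` recommendation from the new docstring (the LoRA hyperparameters below are illustrative, not taken from the diff):

```python
from datasets import load_dataset
from peft import LoraConfig
from trl import RewardTrainer

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

# Keep the freshly initialized reward head ("score") trainable; r/alpha are example values.
peft_config = LoraConfig(r=16, lora_alpha=32, modules_to_save=["score"])

trainer = RewardTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
    peft_config=peft_config,
)
trainer.train()
```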
@@ -145,78 +361,47 @@ class RewardTrainer(BaseTrainer):
         if args.disable_dropout:
             disable_dropout_in_model(model)

-        if compute_metrics is None:
-            compute_metrics = compute_accuracy
+        # Pad token (needed for SequenceClassification models)
+        # If not provided, use the one from the processing class or the eos token if the processing class does not have
+        # a pad token.
+        pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
+        pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
+        if pad_token_id is None:
+            raise ValueError(
+                f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
+                f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
+                "in the vocabulary before using it as a padding token."
+            )
+        model.config.pad_token_id = pad_token_id
+        processing_class.pad_token_id = pad_token_id
+
+        # Data collator
         if data_collator is None:
-            if processing_class is None:
-                raise ValueError(
-                    "A processing_class must be specified when using the default RewardDataCollatorWithPadding"
+            data_collator = DataCollatorForPreference(
+                pad_token_id=pad_token_id,
+                pad_to_multiple_of=args.pad_to_multiple_of,
             )

-            max_length = args.max_length
-            data_collator = RewardDataCollatorWithPadding(processing_class)
-
-            if args.remove_unused_columns:
-                try:  # for bc before https://github.com/huggingface/transformers/pull/25435
-                    args.remove_unused_columns = False
-                except FrozenInstanceError:
-                    args = replace(args, remove_unused_columns=False)
-                # warn users
-                logger.warning(
-                    "When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
-                    " we have set it for you, but you should do it yourself in the future.",
-                )
-
-            self.use_reward_data_collator = True
-        else:
-            self.use_reward_data_collator = False
-
-        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
-        # input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
-        # "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
-        # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
-        # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
-        # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
-        # issued.
-        model.warnings_issued["estimate_tokens"] = True
-
-        if "input_ids_chosen" not in train_dataset.column_names:
-            with PartialState().main_process_first():
-                fn_kwargs = {"tokenizer": processing_class}
-                train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
-                train_dataset = train_dataset.map(
-                    _tokenize,
-                    batched=True,
-                    fn_kwargs=fn_kwargs,
-                    num_proc=args.dataset_num_proc,
-                )
-                # This filter is important because otherwise you get samples that exceed the model's context length and
-                # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                # user might get surprised if N samples are missing from training.
-                train_dataset = train_dataset.filter(
-                    lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
-                    num_proc=args.dataset_num_proc,
-                )
-                if eval_dataset is not None:
-                    eval_dataset = eval_dataset.map(
-                        maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
-                    )
-                    eval_dataset = eval_dataset.map(
-                        _tokenize,
-                        fn_kwargs=fn_kwargs,
-                        batched=True,
-                        num_proc=args.dataset_num_proc,
-                    )
-                    # This filter is important because otherwise you get samples that exceed the model's context length and
-                    # get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
-                    # user might get surprised if N samples are missing from training.
-                    eval_dataset = eval_dataset.filter(
-                        lambda x: len(x["input_ids_chosen"]) <= max_length
-                        and len(x["input_ids_rejected"]) <= max_length,
-                        num_proc=args.dataset_num_proc,
-                    )
+        # Dataset
+        train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
+        if eval_dataset is not None:
+            if isinstance(eval_dataset, dict):
+                eval_dataset = {
+                    key: self._prepare_dataset(dataset, processing_class, args, key)
+                    for key, dataset in eval_dataset.items()
+                }
+            else:
+                eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
+
+        # Initialize the metrics
+        self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
+        self._total_train_tokens = 0
+
+        # Initialize the Trainer. Parent class will handle:
+        # - DeepSpeed configuration (through create_accelerator_and_postprocess)
+        # - FSDP setup
+        # - Distributed training setup
+        # - Optimizer and scheduler creation

         super().__init__(
             model=model,
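As the dataset-preparation calls above suggest (and the new class docstring states), already-tokenized datasets are passed through untouched as long as they expose the two id columns. A small sketch under that assumption (the token ids are toy values):

```python
from datasets import Dataset
from trl import RewardTrainer

# Toy, already-tokenized preference data: only these two columns are required.
train_dataset = Dataset.from_dict(
    {
        "chosen_input_ids": [[1, 2, 3], [4, 5]],
        "rejected_input_ids": [[6, 7], [8, 9, 10]],
    }
)

trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=train_dataset)
```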
@@ -225,35 +410,140 @@ class RewardTrainer(BaseTrainer):
             train_dataset=train_dataset,
             eval_dataset=eval_dataset,
             processing_class=processing_class,
-            model_init=model_init,
             compute_metrics=compute_metrics,
             callbacks=callbacks,
             optimizers=optimizers,
+            optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
             preprocess_logits_for_metrics=preprocess_logits_for_metrics,
         )
+
+        # During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty.
+        self.can_return_loss = True
+        self.label_names = []
+
+        # Initialize activation offloading context
+        if self.args.activation_offloading:
+            self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
+        else:
+            self.maybe_activation_offload_context = contextlib.nullcontext()

         # Add tags for models that have been loaded with the correct transformers version
         if hasattr(self.model, "add_model_tags"):
             self.model.add_model_tags(self._tag_names)
+
+        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
+
+    def _prepare_dataset(
+        self,
+        dataset: Union[Dataset, IterableDataset],
+        processing_class: PreTrainedTokenizerBase,
+        args: RewardConfig,
+        dataset_name: str,
+    ) -> Union[Dataset, IterableDataset]:
+        # Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
+        # sampled data.
+        if isinstance(dataset, Dataset):  # IterableDataset does not support `with_transform`
+            dataset = dataset.with_transform(remove_none_values)
+
+        # If the dataset is already preprocessed (tokenized), skip the processing steps.
+        column_names = list(next(iter(dataset)).keys())
+        is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names
+
+        # Build the kwargs for the `map` function
+        map_kwargs = {}
+        if isinstance(dataset, Dataset):  # IterableDataset does not support num_proc
+            map_kwargs["num_proc"] = args.dataset_num_proc
+
+        with PartialState().main_process_first():
+            if not is_processed:
+                # Add EOS token to the end of the sequences if needed
+                first_example = next(iter(dataset))
+                if not is_conversational(first_example):
+                    if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                        map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"
+
+                    def add_eos(example, eos_token):
+                        if not example["chosen"].endswith(eos_token):
+                            example["chosen"] = example["chosen"] + eos_token
+                        if "rejected" in example and not example["rejected"].endswith(eos_token):
+                            example["rejected"] = example["rejected"] + eos_token
+                        return example
+
+                    dataset = dataset.map(
+                        add_eos,
+                        fn_kwargs={"eos_token": processing_class.eos_token},
+                        **map_kwargs,
+                    )
+
+                # Tokenize the dataset
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
+
+                def tokenize_fn(example, processing_class):
+                    if "prompt" in example:  # explicit prompt case
+                        example["chosen"] = example["prompt"] + example["chosen"]
+                        example["rejected"] = example["prompt"] + example["rejected"]
+
+                    if is_conversational(example):
+                        chosen_input_ids = processing_class.apply_chat_template(
+                            example["chosen"],
+                            tools=example.get("tools"),
+                            **example.get("chat_template_kwargs", {}),
+                        )
+                        rejected_input_ids = processing_class.apply_chat_template(
+                            example["rejected"],
+                            tools=example.get("tools"),
+                            **example.get("chat_template_kwargs", {}),
+                        )
+                        output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids}
+                    else:
+                        output = {
+                            "chosen_input_ids": processing_class(text=example["chosen"])["input_ids"],
+                            "rejected_input_ids": processing_class(text=example["rejected"])["input_ids"],
+                        }
+                    return output
+
+                dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs)
+
+            # Filter samples that are longer than `max_length`
+            if args.max_length is not None:
+                if isinstance(dataset, Dataset):  # `IterableDataset.map` does not support `desc`
+                    map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens"
+                dataset = dataset.filter(
+                    lambda example: len(example["chosen_input_ids"]) <= args.max_length
+                    and len(example["rejected_input_ids"]) <= args.max_length,
+                    **map_kwargs,
+                )
+
+        return dataset
+
+    def _set_signature_columns_if_needed(self):
+        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
+        # By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
+        # and "attention_mask").
+        if self._signature_columns is None:
+            self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"]
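To make the implicit- vs explicit-prompt handling in `tokenize_fn` above concrete, here is a hedged sketch of the two standard-format layouts it accepts (the strings are invented examples):

```python
# Implicit prompt: the prompt text is already baked into both completions.
implicit_example = {
    "chosen": "The capital of France is Paris.",
    "rejected": "The capital of France is Lyon.",
}

# Explicit prompt: stored once and prepended to both sides before tokenization,
# mirroring `example["chosen"] = example["prompt"] + example["chosen"]` above.
explicit_example = {
    "prompt": "The capital of France is",
    "chosen": " Paris.",
    "rejected": " Lyon.",
}
```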
     def compute_loss(
         self,
-        model: Union[PreTrainedModel, nn.Module],
+        model: nn.Module,
         inputs: dict[str, Union[torch.Tensor, Any]],
-        return_outputs=False,
-        num_items_in_batch=None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-        rewards_chosen = model(
-            input_ids=inputs["input_ids_chosen"],
-            attention_mask=inputs["attention_mask_chosen"],
-            return_dict=True,
-        )["logits"]
-        rewards_rejected = model(
-            input_ids=inputs["input_ids_rejected"],
-            attention_mask=inputs["attention_mask_rejected"],
-            return_dict=True,
-        )["logits"]
-        # calculate loss, optionally modulate with margin
+        return_outputs: bool = False,
+        num_items_in_batch: Optional[torch.Tensor] = None,
+    ):
+        """
+        Compute training loss and additionally compute token accuracies
+        """
+        mode = "train" if self.model.training else "eval"
+
+        # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
+        inputs["use_cache"] = False
+        outputs = model(**inputs)
+
+        # Split the rewards into chosen and rejected
+        rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2)
+
+        # Calculate loss, optionally modulate with margin
         if "margin" in inputs:
             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
         else:
@@ -262,86 +552,45 @@ class RewardTrainer(BaseTrainer):
         if self.args.center_rewards_coefficient is not None:
             loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)

-        if return_outputs:
-            return loss, {
-                "rewards_chosen": rewards_chosen,
-                "rewards_rejected": rewards_rejected,
-            }
-        return loss
-
-    def prediction_step(
-        self,
-        model: Union[PreTrainedModel, nn.Module],
-        inputs: dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: Optional[list[str]] = None,
-    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
-        inputs = self._prepare_inputs(inputs)
-        if ignore_keys is None:
-            if hasattr(self.model, "config"):
-                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
-            else:
-                ignore_keys = []
+        if mode == "train":
+            num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
+            self._total_train_tokens += num_tokens_in_batch
+        self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
+
+        # Compute min, mean, max, accuracy and margin
         with torch.no_grad():
-            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
-
-        if prediction_loss_only:
-            return (loss, None, None)
-
-        loss = loss.detach()
-        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
-        logits = nested_detach(logits)
-        # Stack accepted against rejected, mean over logits
-        # and softmax to get preferences between accepted and rejected to sum to 1
-        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
-
-        labels = torch.zeros(logits.shape[0])
-        labels = self._prepare_inputs(labels)
-
-        return loss, logits, labels
-
-    def evaluate(self, *args, **kwargs):
-        num_print_samples = kwargs.pop("num_print_samples", 4)
-        self.visualize_samples(num_print_samples)
-        return super().evaluate(*args, **kwargs)
-
-    def visualize_samples(self, num_print_samples: int):
-        """
-        Visualize the reward model logits prediction
-
-        Args:
-            num_print_samples (`int`, defaults to `4`):
-                The number of samples to print. Set to `-1` to print all samples.
-        """
-        eval_dataloader = self.get_eval_dataloader()
-        table = defaultdict(list)
-        for _, inputs in enumerate(eval_dataloader):
-            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
-            chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
-            rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
-            table["chosen_text"].extend(gather_object(chosen_text))
-            table["rejected_text"].extend(gather_object(rejected_text))
-            table["logits"].extend(
-                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
-            )
-            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
-                break
-        df = pd.DataFrame(table)
-        if self.accelerator.process_index == 0:
-            if is_rich_available():
-                print_rich_table(df[:num_print_samples])
-            if "wandb" in self.args.report_to:
-                import wandb
-
-                if wandb.run is not None:
-                    wandb.log({"completions": wandb.Table(dataframe=df)})
-
-            if "comet_ml" in self.args.report_to:
-                log_table_to_comet_experiment(
-                    name="completions.csv",
-                    table=df,
-                )
+            all_rewards = self.accelerator.gather(outputs.logits)
+            self._metrics[mode]["min_reward"].append(all_rewards.min().item())
+            self._metrics[mode]["mean_reward"].append(all_rewards.mean().item())
+            self._metrics[mode]["max_reward"].append(all_rewards.max().item())
+
+            mean_accuracy = (rewards_chosen > rewards_rejected).float().mean()
+            mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item()
+            self._metrics[mode]["accuracy"].append(mean_accuracy)
+
+            mean_margin = (rewards_chosen - rewards_rejected).mean()
+            mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean()
+            self._metrics[mode]["margin"].append(mean_margin.item())
+
+        return (loss, outputs) if return_outputs else loss
+
+    # Override training step to add activation offloading context.
+    def training_step(self, *args, **kwargs):
+        with self.maybe_activation_offload_context:
+            return super().training_step(*args, **kwargs)
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        mode = "train" if self.model.training else "eval"
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics
+
+        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+        if mode == "eval":
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs.update(metrics)
+        super().log(logs, start_time)
+        self._metrics[mode].clear()

     # Ensure the model card is saved along with the checkpoint
     def _save_checkpoint(self, model, trial):
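For reference, the training loss assembled in `compute_loss` above is a pairwise logistic term, optionally shifted by a per-pair margin and regularized so rewards stay centred. A small numerical sketch with invented values (the 0.01 coefficient stands in for `center_rewards_coefficient` and is arbitrary):

```python
import torch
import torch.nn.functional as F

rewards_chosen = torch.tensor([1.2, 0.4])
rewards_rejected = torch.tensor([0.3, 0.6])
margin = torch.tensor([0.5, 0.0])

loss = -F.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()
# Optional auxiliary term that discourages the rewards from drifting away from zero.
loss = loss + 0.01 * torch.mean((rewards_chosen + rewards_rejected) ** 2)
print(loss)
```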
@@ -15,10 +15,9 @@
 import contextlib
 import os
 from collections import defaultdict
-from collections.abc import Mapping
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, Union

 import torch
 import torch.nn as nn
@@ -56,6 +55,7 @@ from .utils import (
     entropy_from_logits,
     flush_left,
     pad,
+    remove_none_values,
     selective_log_softmax,
 )
@@ -66,7 +66,6 @@ if is_peft_available():

 logger = logging.get_logger(__name__)

-TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
-
 FLASH_ATTENTION_VARIANTS = {
     "flash_attention_2",
@@ -77,38 +76,6 @@ FLASH_ATTENTION_VARIANTS = {
 }

-
-def remove_none_values(example: TListOrMapping) -> TListOrMapping:
-    """
-    Recursively removes entries with `None` values from a nested structure (list or dictionary).
-
-    Args:
-        example (`list` or `Mapping`):
-            Input nested structure (list or dictionary) from which to remove `None`.
-
-    Example:
-    ```python
-    >>> [
-    ...     {
-    ...         "a": {"aa": None, "ab": 1},
-    ...         "b": "my_string",
-    ...     }
-    ... ]
-    >>> remove_none_values(example)
-    [{'a': {'ab': 1}, 'b': 'my_string'}]
-    ```
-    """
-    if isinstance(example, list):
-        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
-    elif isinstance(example, Mapping):
-        return {
-            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
-            for key, value in example.items()
-            if value is not None
-        }
-    else:
-        raise TypeError("Input must be a list or a dictionary.")
-
-
 def get_dataset_column_names(dataset: Union[Dataset, IterableDataset]) -> list[str]:
     return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names
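The import hunk above reflects that `remove_none_values` now lives in `trl.trainer.utils` (it is re-added there in the last hunk below). A one-line usage sketch of the helper:

```python
from trl.trainer.utils import remove_none_values

print(remove_none_values({"a": {"aa": None, "ab": 1}, "b": "my_string"}))
# {'a': {'ab': 1}, 'b': 'my_string'}
```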
@@ -18,11 +18,12 @@ import json
 import os
 import random
 import socket
-from collections.abc import Sequence, Sized
+import warnings
+from collections.abc import Mapping, Sequence, Sized
 from dataclasses import dataclass, field
 from importlib.metadata import version
 from itertools import accumulate
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, TypeVar, Union

 import numpy as np
 import pandas as pd
@@ -219,6 +220,10 @@ class RewardDataCollatorWithPadding:
     r"""
     Reward DataCollator class that pads the inputs to the maximum length of the batch.

+    > [!WARNING]
+    > This class is deprecated and will be removed in version 0.27.0. Please use
+    > `trl.trainer.reward_trainer.DataCollatorForPreference` instead.
+
     Args:
         tokenizer (`PreTrainedTokenizerBase`):
             The tokenizer used for encoding the data.
@@ -235,6 +240,14 @@ class RewardDataCollatorWithPadding:
     pad_to_multiple_of: Optional[int] = None
     return_tensors: str = "pt"

+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The `RewardDataCollatorWithPadding` is deprecated and will be removed in version 0.27.0. Please use "
+            "`trl.trainer.reward_trainer.DataCollatorForPreference` instead.",
+            DeprecationWarning,
+        )
+        super().__init__(*args, **kwargs)
+
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         features_chosen = []
         features_rejected = []
@@ -1241,6 +1254,10 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
     """
     Decodes the input tensor and strips the padding tokens.

+    > [!WARNING]
+    > This function is deprecated and will be removed in a version 0.25.0. If you want to keep using it, please copy
+    > the code into your codebase and use it from there.
+
     Args:
         inputs (`torch.Tensor`):
            The input tensor to be decoded.
@@ -1251,6 +1268,11 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
         `list[str]`:
             The list of decoded strings with padding tokens stripped.
     """
+    warnings.warn(
+        "The function `decode_and_strip_padding` is deprecated and will be removed in a version 0.25.0. If you want "
+        "to keep using it, please copy the code into your codebase and use it from there.",
+        DeprecationWarning,
+    )
     decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False)
     return [d.replace(tokenizer.pad_token, "") for d in decoded]

@@ -1264,6 +1286,7 @@ def generate_model_card(
     wandb_url: Optional[str],
     trainer_name: str,
     trainer_citation: Optional[str] = None,
+    template_file: Optional[str] = None,
     paper_title: Optional[str] = None,
     paper_id: Optional[str] = None,
     comet_url: Optional[str] = None,
@@ -1290,6 +1313,8 @@ def generate_model_card(
             Trainer name.
         trainer_citation (`str` or `None`, defaults to `None`):
             Trainer citation as a BibTeX entry.
+        template_file (`str` *optional*):
+            Template file name located in the `trl/templates` directory. Defaults to `lm_model_card.md`.
         paper_title (`str` or `None`, defaults to `None`):
             Paper title.
         paper_id (`str` or `None`, defaults to `None`):
@@ -1307,9 +1332,10 @@ def generate_model_card(
         model_name=model_name,
         tags=["generated_from_trainer", *tags],
     )
+    template_file = template_file or "lm_model_card.md"
     card = ModelCard.from_template(
         card_data,
-        template_path=str(pkg_resources.files("trl").joinpath("templates/lm_model_card.md")),
+        template_path=str(pkg_resources.files("trl").joinpath(f"templates/{template_file}")),
         base_model=base_model,
         model_name=model_name,
         hub_model_id=hub_model_id,
@@ -1956,6 +1982,41 @@ def truncate_with_protected_tokens(
     return torch.stack(truncated_seq), torch.stack(truncated_mask)


+TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
+
+
+def remove_none_values(example: TListOrMapping) -> TListOrMapping:
+    """
+    Recursively removes entries with `None` values from a nested structure (list or dictionary).
+
+    Args:
+        example (`list` or `Mapping`):
+            Input nested structure (list or dictionary) from which to remove `None`.
+
+    Example:
+    ```python
+    >>> [
+    ...     {
+    ...         "a": {"aa": None, "ab": 1},
+    ...         "b": "my_string",
+    ...     }
+    ... ]
+    >>> remove_none_values(example)
+    [{'a': {'ab': 1}, 'b': 'my_string'}]
+    ```
+    """
+    if isinstance(example, list):
+        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
+    elif isinstance(example, Mapping):
+        return {
+            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
+            for key, value in example.items()
+            if value is not None
+        }
+    else:
+        raise TypeError("Input must be a list or a dictionary.")
+
+
 def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
     """
     Create a model from a given path using the specified initialization arguments.
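Since `RewardDataCollatorWithPadding` is deprecated above in favour of the new collator, a hedged migration sketch (the tokenizer choice is illustrative and is assumed to define a pad token):

```python
from transformers import AutoTokenizer
from trl.trainer.reward_trainer import DataCollatorForPreference

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# Old: RewardDataCollatorWithPadding(tokenizer) over "input_ids_chosen"/"input_ids_rejected" features.
# New: only the pad token id is needed, and the columns are "chosen_input_ids"/"rejected_input_ids".
collator = DataCollatorForPreference(pad_token_id=tokenizer.pad_token_id)
batch = collator(
    [
        {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
        {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
    ]
)
print(batch["input_ids"].shape)  # torch.Size([4, 3]); chosen rows first, rejected rows second
```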