🎁 RewardTrainer refactor (#4093)

Co-authored-by: juejuezi <juejuezi.git@foxmail.com>
Co-authored-by: Yi Shi <96773624+singing-cat@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Quentin Gallouédec
2025-09-30 15:13:45 -06:00
committed by GitHub
parent ebb8899f5d
commit da209f89fc
19 changed files with 1982 additions and 529 deletions

View File

@ -136,23 +136,13 @@ trainer.train()
Here is a basic example of how to use the [`RewardTrainer`](https://huggingface.co/docs/trl/reward_trainer):

```python
from trl import RewardTrainer
from datasets import load_dataset

dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

trainer = RewardTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    train_dataset=dataset,
)
trainer.train()
```

View File

@ -9,6 +9,7 @@ Currently supported commands are:
- `trl dpo`: fine-tune a LLM with DPO
- `trl grpo`: fine-tune a LLM with GRPO
- `trl kto`: fine-tune a LLM with KTO
- `trl reward`: train a Reward Model
- `trl rloo`: fine-tune a LLM with RLOO
- `trl sft`: fine-tune a LLM with SFT
@ -41,6 +42,15 @@ trl dpo \
    --dataset_name anthropic/hh-rlhf
```

</hfoption>
<hfoption id="Reward">

```bash
trl reward \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/ultrafeedback_binarized
```

</hfoption>
</hfoptions>
@ -78,6 +88,21 @@ Launch with:
trl dpo --config dpo_config.yaml
```

</hfoption>
<hfoption id="Reward">

```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/ultrafeedback_binarized
```

Launch with:

```bash
trl reward --config reward_config.yaml
```

</hfoption>
</hfoptions>
@ -138,6 +163,33 @@ Launch with:
```bash
trl dpo --config dpo_config.yaml
```

</hfoption>
<hfoption id="Reward inline">

```bash
trl reward \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --num_processes 4
```

</hfoption>
<hfoption id="Reward w/ config file">

```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/ultrafeedback_binarized
num_processes: 4
```

Launch with:

```bash
trl reward --config reward_config.yaml
```

</hfoption>
</hfoptions>
@ -217,6 +269,33 @@ Launch with:
```bash
trl dpo --config dpo_config.yaml
```

</hfoption>
<hfoption id="Reward inline">

```bash
trl reward \
    --model_name_or_path Qwen/Qwen2.5-0.5B \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --accelerate_config zero2  # or path/to/my/accelerate/config.yaml
```

</hfoption>
<hfoption id="Reward w/ config file">

```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
dataset_name: trl-lib/ultrafeedback_binarized
accelerate_config: zero2  # or path/to/my/accelerate/config.yaml
```

Launch with:

```bash
trl reward --config reward_config.yaml
```

</hfoption>
</hfoptions>
@ -224,7 +303,7 @@ trl dpo --config dpo_config.yaml
You can use dataset mixtures to combine multiple datasets into a single training dataset. This is useful for training on diverse data sources or when you want to mix different types of data.

<hfoptions id="dataset_mixtures">
<hfoption id="SFT">

```yaml
@ -258,6 +337,23 @@ Launch with:
trl dpo --config dpo_config.yaml
```

</hfoption>
<hfoption id="Reward">

```yaml
# reward_config.yaml
model_name_or_path: Qwen/Qwen2.5-0.5B
datasets:
  - path: trl-lib/tldr-preference
  - path: trl-lib/lm-human-preferences-sentiment
```

Launch with:

```bash
trl reward --config reward_config.yaml
```

</hfoption>
</hfoptions>

View File

@ -533,3 +533,53 @@ training_args = CPOConfig(
    ...
)
```
## Reward Modeling
Papers relating to the [`RewardTrainer`]
### Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking
**📜 Paper**: https://huggingface.co/papers/2312.09244
This paper proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs and thereby resolving the issue of underdetermination.
$$
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \textcolor{red}{- \eta \cdot (r_\theta(x, y^+) + r_\theta(x, y^-))^2} \right].
$$
To use this auxiliary loss with [`RewardTrainer`], you can use the `center_rewards_coefficient` argument in [`RewardConfig`] as follows:
```python
from trl import RewardConfig
training_args = RewardConfig(
    center_rewards_coefficient=0.01,  # η in the paper
    ...
)
```
### Llama 2: Open Foundation and Fine-Tuned Chat Models
**📜 Paper**: https://huggingface.co/papers/2307.09288
In this paper, the authors propose to leverage their preference ratings being decomposed as a scale of four points (e.g., _significantly better_) to provide more informative feedback to the reward model. This is done by adding a margin to the loss function, which encourages the reward model to assign larger gaps in scores for pairs with higher preference ratings.
$$
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-,\textcolor{red}{m}) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-) \textcolor{red}{- m}) \right].
$$
You can add a margin to the loss by adding a `margin` column to the dataset. The following example shows how to set up the "Margin Small" setting of the paper.
```python
def add_margin(example):
    preference_to_margin = {
        "significantly better": 1.0,
        "better": 2.0 / 3.0,
        "slightly better": 1.0 / 3.0,
        "negligibly better / unsure": 0.0,
    }
    return {"margin": preference_to_margin[example["preference_label"]]}

dataset = dataset.map(add_margin)
```

View File

@ -1,6 +1,6 @@
# Quickstart

TRL is a comprehensive library for post-training foundation models using techniques like Supervised Fine-Tuning (SFT), Group Relative Policy Optimization (GRPO), and Direct Preference Optimization (DPO).

## Quick Examples
@ -51,6 +51,21 @@ trainer = DPOTrainer(
trainer.train()
```
### Reward Modeling
```python
from trl import RewardTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = RewardTrainer(
model="Qwen/Qwen2.5-0.5B-Instruct",
train_dataset=dataset,
)
trainer.train()
```
## Command Line Interface

Skip the code entirely - train directly from your terminal:
@ -63,6 +78,10 @@ trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
# DPO: Align with preferences
trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name trl-lib/ultrafeedback_binarized

# Reward: Train a reward model
trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
    --dataset_name trl-lib/ultrafeedback_binarized
```

## What's Next?

View File

@ -2,84 +2,225 @@
[![](https://img.shields.io/badge/All_models-Reward_Trainer-blue)](https://huggingface.co/models?other=reward-trainer,trl)

## Overview

TRL supports the Outcome-supervised Reward Modeling (ORM) Trainer for training reward models.

This post-training method was contributed by [Younes Belkada](https://huggingface.co/ybelkada).

## Quick start

This example demonstrates how to train a reward model using the [`RewardTrainer`] from TRL. We train a [Qwen 3 0.6B](https://huggingface.co/Qwen/Qwen3-0.6B) model on the [UltraFeedback dataset](https://huggingface.co/datasets/trl-lib/ultrafeedback_binarized), a large-scale, fine-grained, diverse preference dataset.

```python
from trl import RewardTrainer
from datasets import load_dataset

trainer = RewardTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=load_dataset("trl-lib/ultrafeedback_binarized", split="train"),
)
trainer.train()
```
<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train*&sidebar=hidden&runs=reward_qwen3-0.6B_ultrafeedback2" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>
## Expected dataset type and format
[`RewardTrainer`] supports [preference](dataset_formats#preference) datasets (with both implicit and explicit prompts) and is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
```python
# Standard preference (implicit prompt)
{"chosen": "The sky is blue.",
"rejected": "The sky is green."}
# Conversational preference (implicit prompt)
{"chosen": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "user", "content": "What color is the sky?"},
{"role": "assistant", "content": "It is green."}]}
# Standard preference (explicit prompt)
{"prompt": "The sky is",
"chosen": " blue.",
"rejected": " green."}
# Conversational preference (explicit prompt)
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
"chosen": [{"role": "assistant", "content": "It is blue."}],
"rejected": [{"role": "assistant", "content": "It is green."}]}
```
If your dataset is not in one of these formats, you can preprocess it to convert it into the expected format. Here is an example with the [lmarena-ai/arena-human-preference-55k](https://huggingface.co/datasets/lmarena-ai/arena-human-preference-55k) dataset:
```python
from datasets import load_dataset
import json
dataset = load_dataset("lmarena-ai/arena-human-preference-55k")
# Filter out ties
dataset = dataset.filter(lambda example: example["winner_tie"] == 0)

# Create 'chosen' and 'rejected' fields based on the winner column
def response_a_b_to_chosen_rejected(example):
    if example["winner_model_a"] == 1:
        example["chosen"] = example["response_a"]
        example["rejected"] = example["response_b"]
    else:
        example["chosen"] = example["response_b"]
        example["rejected"] = example["response_a"]
    return example

dataset = dataset.map(response_a_b_to_chosen_rejected)

# Convert to conversational format
def make_conversation(example):
    prompt = json.loads(example["prompt"])[0]  # '["What color is the sky?"]' -> "What color is the sky?"
    chosen = json.loads(example["chosen"])[0]
    rejected = json.loads(example["rejected"])[0]
    return {
        "chosen": [{"role": "user", "content": prompt}, {"role": "assistant", "content": chosen}],
        "rejected": [{"role": "user", "content": prompt}, {"role": "assistant", "content": rejected}],
    }

dataset = dataset.map(make_conversation)

# Keep only necessary columns
dataset = dataset.select_columns(["chosen", "rejected"])

print(next(iter(dataset["train"])))
```
```json
{
"chosen": [
{"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"},
{"role": "assistant", "content": "The question of whether it is morally right to aim for a certain percentage of females..."},
],
"rejected": [
{"role": "user", "content": "Is it morally right to try to have a certain percentage of females on managerial positions?"},
{"role": "assistant", "content": "As an AI, I don't have personal beliefs or opinions. However, ..."},
],
}
```
## Looking deeper into the training method
Reward Models (RMs) are typically trained using supervised learning on datasets containing pairs of preferred and non-preferred responses. The goal is to learn a function that assigns higher scores to preferred responses, enabling the model to rank outputs based on preferences.
This section breaks down how reward modeling works in practice, covering the key steps: **preprocessing** and **loss computation**.
### Preprocessing and tokenization
During training, each example is expected to contain a **chosen** and **rejected** field. For more details on the expected formats, see [Dataset formats - Preference](dataset_formats#preference).
The [`RewardTrainer`] tokenizes each input using the model's tokenizer. If prompts and completions (chosen and rejected) are provided separately (explicit prompt case), they are concatenated before tokenization.
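For intuition, the snippet below is an illustrative sketch of what this step produces for an implicit-prompt pair, using the tokenizer directly rather than the trainer's internal code; the `chosen_input_ids`/`rejected_input_ids` naming mirrors the pretokenized format accepted by the trainer.

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

example = {"chosen": "The sky is blue.", "rejected": "The sky is green."}

# Each side of the pair is tokenized independently; the data collator later pads
# them and stacks the chosen and rejected sequences into a single batch.
tokenized = {
    "chosen_input_ids": tokenizer(example["chosen"]).input_ids,
    "rejected_input_ids": tokenizer(example["rejected"]).input_ids,
}
```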
### Computing the loss
Let \\( x \\) be the input sequence (prompt) and \\( y^+ \\) and \\( y^- \\) be the chosen and rejected sequences respectively. Under the Bradley-Terry model ([Bradley & Terry, 1952](https://www.jstor.org/stable/2334029)), the probability that \\( y^+ \\) is preferred over \\( y^- \\) given a reward function \\( r \\) is \\( p(y^+ \succ y^- \mid x) = \sigma(r(x, y^+) - r(x, y^-)) \\), where \\( \sigma \\) is the sigmoid function.
The reward model \\( r_\theta(x, y) \\) is trained to assign higher scores to preferred responses \\( y^+ \\) over non-preferred ones \\( y^- \\). The loss is then defined as the negative log-likelihood of the observed preferences:
$$
\mathcal{L}(\theta) = - \mathbb{E}_{(x,y^+,y^-) \sim \mathcal{D}} \left[ \log \sigma(r_\theta(x, y^+) - r_\theta(x, y^-)) \right].
$$
> [!TIP]
> The Bradley-Terry model is underdetermined, meaning that adding a constant to all rewards does not change the preference probabilities. To address this, [Helping or Herding? Reward Model Ensembles Mitigate but do not Eliminate Reward Hacking](https://huggingface.co/papers/2312.09244) proposes adding an auxiliary loss term that encourages the rewards to be centered around zero. This is controlled by the `center_rewards_coefficient` parameter in the [`RewardConfig`]. The recommended value is `1e-2`.
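In code, the objective boils down to a few lines. The following is an illustrative sketch rather than the trainer's actual implementation; `rewards_chosen` and `rewards_rejected` are assumed to be the scalar scores the model assigns to the chosen and rejected completions of a batch:

```python
import torch.nn.functional as F

def bradley_terry_loss(rewards_chosen, rewards_rejected, center_rewards_coefficient=None):
    # Negative log-likelihood of the observed preferences under the Bradley-Terry model
    loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
    if center_rewards_coefficient is not None:
        # Optional auxiliary term (see tip above) that pushes rewards toward a zero mean
        loss = loss + center_rewards_coefficient * (rewards_chosen + rewards_rejected).pow(2).mean()
    return loss
```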
## Logged metrics
While training and evaluating we record the following reward metrics:
* `global_step`: The total number of optimizer steps taken so far.
* `epoch`: The current epoch number, based on dataset iteration.
* `num_tokens`: The total number of tokens processed so far.
* `loss`: The average loss over the last logging interval.
* `accuracy`: The proportion of correct predictions (i.e., the model assigned a higher score to the chosen response than to the rejected one) averaged over the last logging interval.
* `min_reward`: The minimum reward score assigned by the model. This value is averaged over the logging interval.
* `mean_reward`: The average reward score assigned by the model over the last logging interval.
* `max_reward`: The maximum reward score assigned by the model. This value is averaged over the logging interval.
* `margin`: The average margin (difference between chosen and rejected rewards) over the last logging interval.
* `learning_rate`: The current learning rate, which may change dynamically if a scheduler is used.
* `grad_norm`: The L2 norm of the gradients, computed before gradient clipping.
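For intuition on how the reward metrics relate to each other, here is a small sketch (not the trainer's logging code) computing `accuracy` and `margin` from a batch of reward scores:

```python
import torch

rewards_chosen = torch.tensor([0.8, 0.1, 0.4])
rewards_rejected = torch.tensor([0.2, 0.3, -0.5])

accuracy = (rewards_chosen > rewards_rejected).float().mean()  # fraction of correctly ranked pairs
margin = (rewards_chosen - rewards_rejected).mean()  # average gap between chosen and rejected rewards
```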
## Customization
### Model initialization
You can directly pass the kwargs of the [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] method to the [`RewardConfig`]. For example, if you want to load a model in a different precision, analogous to
```python
model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16)
```
you can do so by passing the `model_init_kwargs={"dtype": torch.bfloat16}` argument to the [`RewardConfig`].
```python
import torch

from trl import RewardConfig

training_args = RewardConfig(
    model_init_kwargs={"dtype": torch.bfloat16},
)
```
Note that all keyword arguments of [`~transformers.AutoModelForSequenceClassification.from_pretrained()`] are supported, except for `num_labels`, which is automatically set to 1.
### Train adapters with PEFT
We support tight integration with 🤗 PEFT library, allowing any user to conveniently train adapters and share them on the Hub, rather than training the entire model.
```python
from datasets import load_dataset
from trl import RewardTrainer
from peft import LoraConfig
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = RewardTrainer(
    "Qwen/Qwen3-4B",
    train_dataset=dataset,
    peft_config=LoraConfig(modules_to_save=["score"]),  # important to include the score head when the base model is not a sequence classification model
)
trainer.train()
```
You can also continue training your [`~peft.PeftModel`]. For that, first load a `PeftModel` outside [`RewardTrainer`] and pass it directly to the trainer, without passing the `peft_config` argument.

```python
from datasets import load_dataset
from trl import RewardTrainer
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-Reward-LoRA", is_trainable=True)
dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = RewardTrainer(
    model=model,
    train_dataset=dataset,
)
trainer.train()
```

> [!TIP]
> When training adapters, you typically use a higher learning rate (≈1e-3) since only new parameters are being learned.
>
> ```python
> RewardConfig(learning_rate=1e-3, ...)
> ```
## Tool Calling with Reward Modeling
The [`RewardTrainer`] fully supports fine-tuning models with _tool calling_ capabilities. In this case, each dataset example should include:
* The conversation messages, including any tool calls (`tool_calls`) and tool responses (`tool` role messages)
* The list of available tools in the `tools` column, typically provided as JSON schemas
For details on the expected dataset structure, see the [Dataset Format — Tool Calling](dataset_formats#tool-calling) section.
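For illustration, a conversational preference example with tool calling might look like the following. This is a hedged sketch with made-up field values; refer to the linked section for the exact schema.

```python
# Hypothetical tool schema provided in the "tools" column (JSON-schema style)
get_weather_schema = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

# Hypothetical preference pair where the chosen completion calls the tool
example = {
    "chosen": [
        {"role": "user", "content": "What is the weather in Paris?"},
        {"role": "assistant", "tool_calls": [
            {"type": "function", "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}
        ]},
    ],
    "rejected": [
        {"role": "user", "content": "What is the weather in Paris?"},
        {"role": "assistant", "content": "I am not able to check the weather."},
    ],
    "tools": [get_weather_schema],
}
```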
## RewardTrainer
@ -91,3 +232,7 @@ For reference results, please refer PR [#1932](https://github.com/huggingface/tr
## RewardConfig

[[autodoc]] RewardConfig

## DataCollatorForPreference
[[autodoc]] trainer.reward_trainer.DataCollatorForPreference
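A minimal usage sketch of the collator, consistent with the behavior exercised in the test suite included in this commit: it pads `chosen_input_ids` and `rejected_input_ids` and stacks all chosen sequences before all rejected ones in the returned batch.

```python
from trl.trainer.reward_trainer import DataCollatorForPreference

collator = DataCollatorForPreference(pad_token_id=0)
examples = [
    {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
    {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
]
batch = collator(examples)
# batch["input_ids"] -> [[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]]
# batch["attention_mask"] marks the non-padded positions.
```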

View File

@ -23,7 +23,7 @@ trainer = SFTTrainer(
trainer.train()
```

<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train*&runs=sft_qwen3-0.6B_capybara" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>

## Expected dataset type and format

View File

@ -64,4 +64,4 @@ trainer.train()
will give you a hosted dashboard at https://huggingface.co/spaces/trl-lib/trackio.

<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&sidebar=hidden&runs=sft_qwen3-0.6B_capybara" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>

View File

@ -43,6 +43,7 @@ from transformers import (
    GPT2LMHeadModel,
    GPTNeoXConfig,
    GPTNeoXForCausalLM,
    GPTNeoXForSequenceClassification,
    GptOssConfig,
    GptOssForCausalLM,
    Idefics2Config,
@ -73,6 +74,7 @@ from transformers import (
    Qwen3ForSequenceClassification,
    Qwen3MoeConfig,
    Qwen3MoeForCausalLM,
    Qwen3MoeForSequenceClassification,
    SmolVLMForConditionalGeneration,
    T5ForConditionalGeneration,
)
@ -234,22 +236,46 @@ model = Qwen3ForCausalLM(config)
push_to_hub(model, tokenizer, "small")

# Reward models
for model_id, model_class, suffix in [
    ("EleutherAI/pythia-14m", GPTNeoXForSequenceClassification, None),
    ("meta-llama/Llama-3.2-1B-Instruct", LlamaForSequenceClassification, "3.2"),
    ("Qwen/Qwen2.5-32B-Instruct", Qwen2ForSequenceClassification, "2.5"),
    ("Qwen/Qwen3-4B", Qwen3ForSequenceClassification, None),
]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    kwargs = {
        "num_labels": 1,
        "hidden_size": 16,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "num_hidden_layers": 2,
        "intermediate_size": 32,
    }
    config = AutoConfig.from_pretrained(model_id, **kwargs)
    # Bug in transformers: it ignores num_hidden_layers to build layer_types
    if model_id in ("Qwen/Qwen2.5-32B-Instruct", "Qwen/Qwen3-4B"):
        config.layer_types = config.layer_types[:2]
    model = model_class(config).to(dtype=torch.bfloat16)
    init_weights_tiny_model(model)
    push_to_hub(model, tokenizer, "tiny", suffix)

# MoE Reward models
for model_id, model_class, suffix in [
    ("Qwen/Qwen3-30B-A3B", Qwen3MoeForSequenceClassification, None),
]:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    kwargs = {
        "num_labels": 1,
        "hidden_size": 16,
        "num_attention_heads": 4,
        "num_key_value_heads": 2,
        "num_hidden_layers": 2,
        "intermediate_size": 32,
        "num_experts": 4,
        "num_experts_per_tok": 2,
    }
    config = AutoConfig.from_pretrained(model_id, **kwargs)
    model = model_class(config).to(dtype=torch.bfloat16)
    push_to_hub(model, tokenizer, "tiny", suffix)
@ -315,7 +341,5 @@ for model_id, model_class in [
kwargs["perceiver_config"] = {"hidden_size": 16} kwargs["perceiver_config"] = {"hidden_size": 16}
config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs) config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config, **kwargs)
model = model_class(config).to(dtype=torch.bfloat16) model = model_class(config).to(dtype=torch.bfloat16)
push_to_hub(model, processor, "tiny") push_to_hub(model, processor, "tiny")

View File

@ -67,6 +67,13 @@ class TestCLI(TrlTestCase):
with patch("sys.argv", command.split(" ")): with patch("sys.argv", command.split(" ")):
main() main()
def test_reward(self):
from trl.cli import main
command = f"trl reward --output_dir {self.tmp_dir} --model_name_or_path trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5 --dataset_name trl-internal-testing/zen --dataset_config standard_implicit_prompt_preference --report_to none"
with patch("sys.argv", command.split(" ")):
main()
def test_rloo(self): def test_rloo(self):
from trl.cli import main from trl.cli import main

View File

@ -12,217 +12,823 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import unittest
import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import RewardConfig, RewardTrainer
from trl.trainer.reward_trainer import DataCollatorForPreference

from .testing_utils import TrlTestCase

if is_peft_available():
    from peft import LoraConfig, PeftModel, get_peft_model
class TestDataCollatorForPreference(TrlTestCase):
def test_basic_padding(self):
"""Test basic padding functionality without completion masks."""
self.collator = DataCollatorForPreference(pad_token_id=0)
examples = [
{"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
{"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
]
result = self.collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]]))
torch.testing.assert_close(
result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]])
)
def test_pad_to_multiple_of(self):
"""Test padding to multiple of specified value."""
collator = DataCollatorForPreference(pad_token_id=0, pad_to_multiple_of=4)
examples = [
{"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
{"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
]
result = collator(examples)
torch.testing.assert_close(
result["input_ids"], torch.tensor([[1, 2, 3, 0], [6, 7, 0, 0], [4, 5, 0, 0], [8, 0, 0, 0]])
)
torch.testing.assert_close(
result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
)
def test_single_example(self):
"""Test collator with a single example."""
self.collator = DataCollatorForPreference(pad_token_id=0)
examples = [{"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]}]
result = self.collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
def test_different_pad_token_id(self):
"""Test with different pad token ID."""
collator = DataCollatorForPreference(pad_token_id=999)
examples = [
{"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
{"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
]
result = collator(examples)
torch.testing.assert_close(
result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 999], [4, 5, 999], [8, 999, 999]])
)
torch.testing.assert_close(
result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]])
)
def test_collate_with_margin(self):
self.collator = DataCollatorForPreference(pad_token_id=0)
examples = [
{"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.1},
{"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.2},
]
result = self.collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [6, 7, 0], [4, 5, 0], [8, 0, 0]]))
torch.testing.assert_close(
result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0], [1, 1, 0], [1, 0, 0]])
)
torch.testing.assert_close(result["margin"], torch.tensor([0.1, 0.2]))
class RewardTrainerTester(TrlTestCase):
    @parameterized.expand(
        [
            ("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",),
            ("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification",),
            ("trl-internal-testing/tiny-LlamaForSequenceClassification-3.2",),
        ]
    )
    def test_train(self, model_id):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")

        # Initialize the trainer
        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
        trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=dataset)

        # Save the initial parameters to compare them later
        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        # Train the model
        trainer.train()

        # Check that the training loss is not None
        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # Check the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")

    @parameterized.expand(
        [
            ("standard_preference",),
            ("conversational_preference",),
            ("standard_implicit_prompt_preference",),
            ("conversational_implicit_prompt_preference",),
        ]
    )
    def test_train_dataset_types(self, config_name):
        # Get the dataset
        dataset = load_dataset("trl-internal-testing/zen", config_name, split="train")

        # Initialize the trainer
        training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
        trainer = RewardTrainer(
            model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            args=training_args,
            train_dataset=dataset,
        )

        # Save the initial parameters to compare them later
        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        # Train the model
        trainer.train()

        # Check that the training loss is not None
        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

        # Check the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_model(self):
# Instantiate the model
model = AutoModelForSequenceClassification.from_pretrained(
"trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
)
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_from_causal_lm(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen3ForCausalLM", args=training_args, train_dataset=dataset
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_model_dtype(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(
output_dir=self.tmp_dir,
model_init_kwargs={"dtype": torch.float16},
learning_rate=0.1,
report_to="none",
)
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
# For some reasonn model.layers.0.input_layernorm.weight doesn't change in GitHub Actions but does
# locally. We ignore this parameter for now
if "layernorm" in n:
continue
new_param = trainer.model.get_parameter(n)
# Check the torch dtype
self.assertEqual(new_param.dtype, torch.float16)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
@require_peft @require_peft
def test_train_lora(self): def test_train_dense_with_peft_config(self):
peft_config = LoraConfig( # Get the base model parameter names
task_type=TaskType.SEQ_CLS, model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
inference_mode=False, model = AutoModelForSequenceClassification.from_pretrained(model_id)
r=8, base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
lora_alpha=32,
lora_dropout=0.1, # Get the dataset
) dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none") # Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
trainer = RewardTrainer( trainer = RewardTrainer(
model=self.model, model=model_id,
args=training_args, args=training_args,
processing_class=self.tokenizer, train_dataset=dataset,
train_dataset=dummy_dataset, peft_config=LoraConfig(),
peft_config=peft_config,
) )
previous_trainable_params = {}
previous_non_trainable_params = {}
# due to a change in the way the modules to save are dealt in PEFT. # Save the initial parameters to compare them later
trainable_params_name = ["lora", "modules_to_save"] previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# check gradients are not None
for n, param in trainer.model.named_parameters():
if any(t in n for t in trainable_params_name):
previous_trainable_params[n] = param.clone()
else:
previous_non_trainable_params[n] = param.clone()
# Train the model
trainer.train() trainer.train()
self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) # Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the parameters have changed # Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items(): for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n) new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) if n in base_param_names: # We expect the base model parameters to be the same
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
# Check that the non trainable parameters have not changed elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
for n, param in previous_non_trainable_params.items(): self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
new_param = trainer.model.get_parameter(n)
self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12))
@require_peft @require_peft
def test_train_lora_pretokenized(self): def test_train_moe_with_peft_config(self):
peft_config = LoraConfig( # Get the base model parameter names
task_type=TaskType.SEQ_CLS, model_id = "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification"
inference_mode=False, model = AutoModelForSequenceClassification.from_pretrained(model_id)
r=8, base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
lora_alpha=32,
lora_dropout=0.1, # Get the dataset
) dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
dummy_dataset = dummy_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": self.tokenizer}) # Initialize the trainer
dummy_dataset = dummy_dataset.map(_tokenize, batched=True, fn_kwargs={"tokenizer": self.tokenizer}) training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
trainer = RewardTrainer( trainer = RewardTrainer(
model=self.model, model=model_id,
args=training_args, args=training_args,
processing_class=self.tokenizer, train_dataset=dataset,
train_dataset=dummy_dataset, peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]),
peft_config=peft_config,
) )
previous_trainable_params = {}
previous_non_trainable_params = {}
# due to a change in the way the modules to save are dealt in PEFT. # Save the initial parameters to compare them later
trainable_params_name = ["lora", "modules_to_save"] previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# check gradients are not None
for n, param in trainer.model.named_parameters():
if any(t in n for t in trainable_params_name):
previous_trainable_params[n] = param.clone()
else:
previous_non_trainable_params[n] = param.clone()
# Train the model
trainer.train() trainer.train()
self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) # Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the parameters have changed # Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items(): for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n) new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) if n in base_param_names: # We expect the base model parameters to be the same
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
# Check that the non trainable parameters have not changed @require_peft
for n, param in previous_non_trainable_params.items(): def test_train_peft_model(self):
# Get the base model
model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
model = AutoModelForSequenceClassification.from_pretrained(model_id)
# Get the base model parameter names
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
# Turn the model into a peft model
lora_config = LoraConfig()
model = get_peft_model(model, lora_config)
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n) new_param = trainer.model.get_parameter(n)
self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) if n in base_param_names: # We expect the base model parameters to be the same
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
@require_peft
def test_train_dense_with_peft_config_and_gradient_checkpointing(self):
# Get the base model parameter names
model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
model = AutoModelForSequenceClassification.from_pretrained(model_id)
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
def test_margin(self):
dummy_dataset_dict = {
"input_ids_chosen": [
torch.LongTensor([0, 1, 2]),
],
"attention_mask_chosen": [
torch.LongTensor([1, 1, 1]),
],
"input_ids_rejected": [
torch.LongTensor([0, 2]),
],
"attention_mask_rejected": [
torch.LongTensor([1, 1]),
],
"margin": [
torch.FloatTensor([1.0]),
],
}
dummy_dataset = Dataset.from_dict(dummy_dataset_dict)
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
trainer = RewardTrainer( trainer = RewardTrainer(
model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset model=model_id,
args=training_args,
train_dataset=dataset,
peft_config=LoraConfig(),
) )
batch = [dummy_dataset[0]] # Save the initial parameters to compare them later
batch = trainer.data_collator(batch) previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
batch = {k: v.to(trainer.model.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
loss, outputs = trainer.compute_loss(trainer.model, batch, return_outputs=True)
l_val = -torch.nn.functional.logsigmoid( # Train the model
outputs["rewards_chosen"] - outputs["rewards_rejected"] - batch["margin"] trainer.train()
).mean()
self.assertLess(abs(loss - l_val), 1e-6) # Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
def test_tags(self): # Check the peft params have changed and the base model params have not changed
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train") for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
@require_peft
def test_train_moe_with_peft_config_and_gradient_checkpointing(self):
# Get the base model parameter names
model_id = "trl-internal-testing/tiny-Qwen3MoeForSequenceClassification"
model = AutoModelForSequenceClassification.from_pretrained(model_id)
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
trainer = RewardTrainer(
model=model_id,
args=training_args,
train_dataset=dataset,
peft_config=LoraConfig(target_modules=["up_proj", "down_proj", "score"]),
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
@require_peft
def test_train_with_peft_model_and_gradient_checkpointing(self):
# Get the base model parameter names
model_id = "trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5"
model = AutoModelForSequenceClassification.from_pretrained(model_id)
base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()]
model = get_peft_model(model, LoraConfig())
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset)
# Verify model is a PeftModel
self.assertIsInstance(trainer.model, PeftModel)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the peft params have changed and the base model params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
if n in base_param_names: # We expect the base model parameters to be the same
self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed")
elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_pretokenized_data(self):
# Get the dataset
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
def tokenize_example(example):
return {
"chosen_input_ids": tokenizer(example["chosen"]).input_ids,
"rejected_input_ids": tokenizer(example["rejected"]).input_ids,
}
# Apply tokenization
tokenized_dataset = dataset.map(tokenize_example, remove_columns=["chosen", "rejected"])
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
trainer = RewardTrainer(model=model_id, args=training_args, train_dataset=tokenized_dataset)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_iterable_dataset(self):
# Get the dataset
dataset = load_dataset(
"trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train", streaming=True
)
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, max_steps=3, report_to="none")
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_chat_template_kwargs(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5")
# The following template is a simplified version of the Qwen chat template, where an additional argument
# `role_capital` is used to control the capitalization of roles.
tokenizer.chat_template = '{%- if messages[0]["role"] == "system" -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\n" + messages[0]["content"] + "<|im_end|>\\n" }}{%- else -%} {{ "<|im_start|>" + ("SYSTEM" if role_capital else "system") + "\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n" }}{%- endif -%}{%- for message in messages -%} {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) -%} {{ "<|im_start|>" + (message.role.upper() if role_capital else message.role) + "\\n" + message.content + "<|im_end|>\\n" }} {%- elif message.role == "assistant" -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") }} {%- if message.content -%} {{ "\\n" + message.content }} {%- endif -%} {{ "<|im_end|>\\n" }} {%- elif message.role == "tool" -%} {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") -%} {{ "<|im_start|>" + ("USER" if role_capital else "user") }} {%- endif -%} {{ "\\n<tool_response>\\n" + message.content + "\\n</tool_response>" }} {%- if loop.last or (messages[loop.index0 + 1].role != "tool") -%} {{ "<|im_end|>\\n" }} {%- endif -%} {%- endif -%}{%- endfor -%}{%- if add_generation_prompt -%} {{ "<|im_start|>" + ("ASSISTANT" if role_capital else "assistant") + "\\n" }}{%- endif -%}'
dataset = dataset.add_column("chat_template_kwargs", [{"role_capital": bool(i % 2)} for i in range(len(dataset))])
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
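# Illustrative sketch (not part of the test file): the per-example `chat_template_kwargs` column added above is
# forwarded to `apply_chat_template` when the trainer tokenizes the pairs, so the same messages can render
# differently per sample. The snippet reuses the `tokenizer` configured with the simplified template above;
# `role_capital` is the extra template variable it defines.
messages = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
lowercase_roles = tokenizer.apply_chat_template(messages, tokenize=False, role_capital=False)
uppercase_roles = tokenizer.apply_chat_template(messages, tokenize=False, role_capital=True)
assert "<|im_start|>user" in lowercase_roles and "<|im_start|>USER" in uppercase_roles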
def test_train_with_set_chat_template_from_model(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, chat_template_path="Qwen/Qwen3-4B", report_to="none")
# trl-internal-testing/tiny-GPTNeoXForSequenceClassification doesn't have a chat template set by default
trainer = RewardTrainer(
model="trl-internal-testing/tiny-GPTNeoXForSequenceClassification",
args=training_args,
train_dataset=dataset,
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models
# include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip
# this parameter.
if n == "gpt_neox.final_layer_norm.bias":
continue
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
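# Small numeric sketch (not part of the test file) of the shift invariance mentioned in the comment above: the
# pairwise loss -logsigmoid(r_chosen - r_rejected) used by RewardTrainer depends only on the difference between
# the two rewards, so a parameter whose only effect is to shift both scores by the same constant gets no gradient.
import torch
import torch.nn.functional as F
rewards_chosen = torch.tensor([1.25, -0.5])
rewards_rejected = torch.tensor([0.75, -1.0])
shift = 5.0  # uniform shift applied to both scores
loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()
shifted_loss = -F.logsigmoid((rewards_chosen + shift) - (rewards_rejected + shift)).mean()
assert torch.allclose(loss, shifted_loss)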
def test_train_with_set_chat_template_from_path(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(
output_dir=self.tmp_dir,
chat_template_path=str(pathlib.Path(__file__).parent / "data" / "template.jinja"),
report_to="none",
)
# trl-internal-testing/tiny-GPTNeoXForSequenceClassification doesn't have a chat template set by default
trainer = RewardTrainer(
model="trl-internal-testing/tiny-GPTNeoXForSequenceClassification",
args=training_args,
train_dataset=dataset,
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
# RewardTrainer uses a mean-free loss that cancels uniform shifts in output scores. Since GPT-NeoX models
# include a final LayerNorm, its bias consistently receives zero gradient and remains unchanged, so we skip
# this parameter.
if n == "gpt_neox.final_layer_norm.bias":
continue
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
# Check that the template saved in the output directory is the same as the one used for training
template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja"
self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}")
with open(template_path) as f:
template_content = f.read()
with open(training_args.chat_template_path) as f:
original_template_content = f.read()
self.assertEqual(
template_content, original_template_content, "Chat template content does not match the original"
)
@unittest.skip("Skipping until we have a dataset with tool calls")
def test_train_toolcall_data(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/toolcall", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
self.assertEqual(trainer.model.model_tags, trainer._tag_names)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_train_with_eval(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none")
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
)
# Train the model
trainer.train()
# Check that the eval loss is not None
self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"])
def test_train_with_multiple_eval_dataset(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, eval_strategy="steps", eval_steps=3, report_to="none")
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset["train"],
eval_dataset={"data1": dataset["test"], "data2": dataset["test"]},
)
# Train the model
trainer.train()
# Check that the eval losses are not None
self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"])
self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"])
def test_train_with_gradient_checkpointing(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, gradient_checkpointing=True, report_to="none")
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
def test_tag_added(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
train_dataset=dataset,
)
for tag in ["reward-trainer", "trl"]:
self.assertIn(tag, trainer.model.model_tags)
@require_peft
def test_tag_added_peft(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
train_dataset=dataset,
peft_config=LoraConfig(),
)
for tag in ["reward-trainer", "trl"]:
self.assertIn(tag, trainer.model.model_tags)
def test_train_with_margin(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
def add_margin(example):
# dummy margin based on the length of the chosen summary
return {"margin": len(example["chosen"])}
dataset = dataset.map(add_margin)
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, report_to="none")
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
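# Illustrative sketch (not part of the test file) of how the optional "margin" column enters the loss, mirroring
# `compute_loss` further down in this diff: the pair is only "solved" once the chosen reward exceeds the rejected
# reward by at least the margin, so for a fixed reward gap a larger margin yields a larger loss.
import torch
import torch.nn.functional as F
reward_gap = torch.tensor([0.8])  # rewards_chosen - rewards_rejected
loss_without_margin = -F.logsigmoid(reward_gap).mean()
loss_with_margin = -F.logsigmoid(reward_gap - torch.tensor([0.5])).mean()
assert loss_with_margin > loss_without_margin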
def test_train_with_center_rewards_coefficient(self):
# Get the dataset
dataset = load_dataset("trl-internal-testing/zen", "standard_implicit_prompt_preference", split="train")
# Initialize the trainer
training_args = RewardConfig(output_dir=self.tmp_dir, center_rewards_coefficient=0.01, report_to="none")
trainer = RewardTrainer(
model="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
# Save the initial parameters to compare them later
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
# Train the model
trainer.train()
# Check that the training loss is not None
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed")
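# Illustrative sketch (not part of the test file) of the `center_rewards_coefficient` auxiliary term, mirroring
# `compute_loss` further down in this diff (https://huggingface.co/papers/2312.09244, Eq. 2): it penalizes the
# squared sum of each pair's rewards, nudging the model towards mean-zero outputs without changing their ordering.
import torch
rewards_chosen = torch.tensor([2.5, 1.0])
rewards_rejected = torch.tensor([2.0, 0.25])
center_rewards_coefficient = 0.01
centering_penalty = center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
# `centering_penalty` is added to the pairwise loss; it is zero only when chosen and rejected rewards cancel out.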

View File

@ -24,6 +24,7 @@ from .scripts.dpo import make_parser as make_dpo_parser
from .scripts.env import print_env from .scripts.env import print_env
from .scripts.grpo import make_parser as make_grpo_parser from .scripts.grpo import make_parser as make_grpo_parser
from .scripts.kto import make_parser as make_kto_parser from .scripts.kto import make_parser as make_kto_parser
from .scripts.reward import make_parser as make_reward_parser
from .scripts.rloo import make_parser as make_rloo_parser from .scripts.rloo import make_parser as make_rloo_parser
from .scripts.sft import make_parser as make_sft_parser from .scripts.sft import make_parser as make_sft_parser
from .scripts.utils import TrlParser from .scripts.utils import TrlParser
@ -45,6 +46,7 @@ def main():
subparsers.add_parser("env", help="Print the environment information") subparsers.add_parser("env", help="Print the environment information")
make_grpo_parser(subparsers) make_grpo_parser(subparsers)
make_kto_parser(subparsers) make_kto_parser(subparsers)
make_reward_parser(subparsers)
make_rloo_parser(subparsers) make_rloo_parser(subparsers)
make_sft_parser(subparsers) make_sft_parser(subparsers)
make_vllm_serve_parser(subparsers) make_vllm_serve_parser(subparsers)
@ -111,6 +113,15 @@ def main():
args.training_script_args = sys.argv[2:] # remove "trl" and "kto" args.training_script_args = sys.argv[2:] # remove "trl" and "kto"
launch_command(args) # launch training launch_command(args) # launch training
elif args.command == "reward":
# Get the default args for the launch command
reward_training_script = resources.files("trl.scripts").joinpath("reward.py")
args = launch_command_parser().parse_args([str(reward_training_script)])
# Feed the args to the launch command
args.training_script_args = sys.argv[2:] # remove "trl" and "reward"
launch_command(args) # launch training
elif args.command == "rloo": elif args.command == "rloo":
# Get the default args for the launch command # Get the default args for the launch command
rloo_training_script = resources.files("trl.scripts").joinpath("rloo.py") rloo_training_script = resources.files("trl.scripts").joinpath("rloo.py")

View File

@ -94,11 +94,8 @@ def setup_chat_format(
Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the
embedding layer of the model based on the new special tokens. embedding layer of the model based on the new special tokens.
<Tip warning="true"> > [!WARNING]
> This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead.
This function is deprecated and will be removed in version 0.26.0. Please use [`clone_chat_template`] instead.
</Tip>
If the model already has a chat template, this will throw an error. If you want to overwrite it, please set If the model already has a chat template, this will throw an error. If you want to overwrite it, please set
`tokenizer.chat_template` to `None`. `tokenizer.chat_template` to `None`.

109
trl/scripts/reward.py Normal file
View File

@ -0,0 +1,109 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "trl",
# "peft",
# "trackio",
# "kernels",
# ]
# ///
import argparse
import os
from typing import Optional
from accelerate import logging
from datasets import load_dataset
from trl import (
DatasetMixtureConfig,
ModelConfig,
RewardConfig,
RewardTrainer,
ScriptArguments,
TrlParser,
get_dataset,
get_peft_config,
)
logger = logging.get_logger(__name__)
# Enable logging in a Hugging Face Space
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
def main(script_args, training_args, model_args, dataset_args):
# Load the dataset
if dataset_args.datasets and script_args.dataset_name:
logger.warning(
"Both `datasets` and `dataset_name` are provided. The `datasets` argument will be used to load the "
"dataset and `dataset_name` will be ignored."
)
dataset = get_dataset(dataset_args)
elif dataset_args.datasets and not script_args.dataset_name:
dataset = get_dataset(dataset_args)
elif not dataset_args.datasets and script_args.dataset_name:
dataset = load_dataset(
script_args.dataset_name, name=script_args.dataset_config, streaming=script_args.dataset_streaming
)
else:
raise ValueError("Either `datasets` or `dataset_name` must be provided.")
# Initialize the RewardTrainer
trainer = RewardTrainer(
model=model_args.model_name_or_path,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
peft_config=get_peft_config(model_args),
)
# Train the model
trainer.train()
# Log training complete
trainer.accelerator.print("✅ Training completed.")
# Save and push to Hub
trainer.save_model(training_args.output_dir)
trainer.accelerator.print(f"💾 Model saved to {training_args.output_dir}.")
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
trainer.accelerator.print(f"🤗 Model pushed to the Hub in https://huggingface.co/{trainer.hub_model_id}.")
def make_parser(subparsers: Optional[argparse._SubParsersAction] = None):
dataclass_types = (ScriptArguments, RewardConfig, ModelConfig, DatasetMixtureConfig)
if subparsers is not None:
parser = subparsers.add_parser(
"reward", help="Run the reward training script", dataclass_types=dataclass_types
)
else:
parser = TrlParser(dataclass_types)
return parser
if __name__ == "__main__":
parser = make_parser()
# When using the trl cli, this script may be run with additional arguments that correspond to accelerate arguments.
# To ensure that their parsing does not interfere with the script arguments, parse the arguments with
# `return_remaining_strings=True`, then ignore the remaining strings.
script_args, training_args, model_args, dataset_args, _ = parser.parse_args_and_config(
return_remaining_strings=True
)
main(script_args, training_args, model_args, dataset_args)

View File

@ -0,0 +1,55 @@
---
{{ card_data }}
---
# Model Card for {{ model_name }}
This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}.
It has been trained using [TRL](https://github.com/huggingface/trl).
## Quick start
```python
from transformers import pipeline
text = "The capital of France is Paris."
rewarder = pipeline(model="{{ hub_model_id }}", device="cuda")
output = rewarder(text)[0]
print(output["score"])
```
## Training procedure
{% if wandb_url %}[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>]({{ wandb_url }}){% endif %}
{% if comet_url %}[<img src="https://raw.githubusercontent.com/comet-ml/comet-examples/master/logo/comet_badge.png" alt="Visualize in Comet" width="135" height="20"/>]({{ comet_url }}){% endif %}
This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.
### Framework versions
- TRL: {{ trl_version }}
- Transformers: {{ transformers_version }}
- Pytorch: {{ pytorch_version }}
- Datasets: {{ datasets_version }}
- Tokenizers: {{ tokenizers_version }}
## Citations
{% if trainer_citation %}Cite {{ trainer_name }} as:
```bibtex
{{ trainer_citation }}
```{% endif %}
Cite TRL as:
```bibtex
{% raw %}@misc{vonwerra2022trl,
title = {{TRL: Transformer Reinforcement Learning}},
author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
year = 2020,
journal = {GitHub repository},
publisher = {GitHub},
howpublished = {\url{https://github.com/huggingface/trl}}
}{% endraw %}
```

View File

@ -28,6 +28,7 @@ class BaseTrainer(Trainer):
_tag_names = [] _tag_names = []
_name = "Base" _name = "Base"
_paper = {} _paper = {}
_template_file = None
def create_model_card( def create_model_card(
self, self,
@ -78,6 +79,7 @@ class BaseTrainer(Trainer):
comet_url=get_comet_experiment_url(), comet_url=get_comet_experiment_url(),
trainer_name=self._name, trainer_name=self._name,
trainer_citation=self._paper.get("citation"), trainer_citation=self._paper.get("citation"),
template_file=self._template_file,
paper_title=self._paper.get("title"), paper_title=self._paper.get("title"),
paper_id=self._paper.get("id"), paper_id=self._paper.get("id"),
) )

View File

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional from typing import Any, Optional
from transformers import TrainingArguments from transformers import TrainingArguments
@ -32,22 +32,53 @@ class RewardConfig(TrainingArguments):
command line. command line.
Parameters: Parameters:
max_length (`int` or `None`, *optional*, defaults to `1024`): > Parameters that control the model
Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
limit. This argument is required if you want to use the default data collator. model_init_kwargs (`dict[str, Any]`, *optional*):
Keyword arguments for [`~transformers.AutoModelForSequenceClassification.from_pretrained`], used when the `model`
argument of the [`RewardTrainer`] is provided as a string. If you're training a MoE architecture and want
to include the load balancing/auxiliary loss as a part of the final loss, remember to set
`output_router_logits=True` in this dictionary.
chat_template_path (`str`, *optional*):
If specified, sets the model's chat template. This can either be the path to a tokenizer (local directory
or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, you must
ensure that any special tokens referenced in the template are added to the tokenizer and that the model's
embedding layer is resized accordingly.
disable_dropout (`bool`, *optional*, defaults to `True`): disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model. Whether to disable dropout in the model.
> Parameters that control the data preprocessing
dataset_num_proc (`int`, *optional*): dataset_num_proc (`int`, *optional*):
Number of processes to use for processing the dataset. Number of processes to use for processing the dataset.
eos_token (`str`, *optional*):
Token used to indicate the end of a turn or sequence. If `None`, it defaults to
`processing_class.eos_token`.
pad_token (`str`, *optional*):
Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that is also `None`,
it falls back to `processing_class.eos_token`.
max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the tokenized sequence. Samples are filtered out if either chosen or rejected sequence
exceeds this value. If `None`, no filtering is applied.
pad_to_multiple_of (`int`, *optional*):
If set, the sequences will be padded to a multiple of this value.
> Parameters that control the training
center_rewards_coefficient (`float`, *optional*): center_rewards_coefficient (`float`, *optional*):
Coefficient to incentivize the reward model to output mean-zero rewards (proposed by Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
remove_unused_columns (`bool`, *optional*, defaults to `False`): activation_offloading (`bool`, *optional*, defaults to `False`):
Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if the Whether to offload the activations to the CPU.
dataset is pretokenized.
""" """
_VALID_DICT_FIELDS = TrainingArguments._VALID_DICT_FIELDS + ["model_init_kwargs"]
# Parameters whose default values are overridden from TrainingArguments # Parameters whose default values are overridden from TrainingArguments
learning_rate: float = field(
default=1e-4,
metadata={"help": "The initial learning rate for AdamW."},
)
logging_steps: float = field( logging_steps: float = field(
default=10, default=10,
metadata={ metadata={
@ -70,21 +101,59 @@ class RewardConfig(TrainingArguments):
}, },
) )
max_length: Optional[int] = field( # Parameters that control the model
default=1024, model_init_kwargs: Optional[dict[str, Any]] = field(
default=None,
metadata={ metadata={
"help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that " "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
"exceed the limit. This argument is required if you want to use the default data collator." "the `RewardTrainer` is provided as a string."
},
)
chat_template_path: Optional[str] = field(
default=None,
metadata={
"help": "If specified, sets the model's chat template. This can either be the path to a tokenizer (local "
"directory or Hugging Face Hub model) or a direct path to a Jinja template file. When using a Jinja file, "
"you must ensure that any special tokens referenced in the template are added to the tokenizer and "
"that the model's embedding layer is resized accordingly."
}, },
) )
disable_dropout: bool = field( disable_dropout: bool = field(
default=True, default=True,
metadata={"help": "Whether to disable dropout in the model and reference model."}, metadata={"help": "Whether to disable dropout in the model."},
) )
# Parameters that control the data preprocessing
dataset_num_proc: Optional[int] = field( dataset_num_proc: Optional[int] = field(
default=None, default=None,
metadata={"help": "Number of processes to use for processing the dataset."}, metadata={"help": "Number of processes to use for processing the dataset."},
) )
eos_token: Optional[str] = field(
default=None,
metadata={
"help": "Token used to indicate the end of a turn or sequence. If `None`, it defaults to `processing_class.eos_token`."
},
)
pad_token: Optional[str] = field(
default=None,
metadata={
"help": "Token used for padding. If `None`, it defaults to `processing_class.pad_token`, or if that "
"is also `None`, it falls back to `processing_class.eos_token`."
},
)
max_length: Optional[int] = field(
default=1024,
metadata={
"help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated from"
"the right. If `None`, no truncation is applied."
},
)
pad_to_multiple_of: Optional[int] = field(
default=None,
metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
)
# Parameters that control the training
center_rewards_coefficient: Optional[float] = field( center_rewards_coefficient: Optional[float] = field(
default=None, default=None,
metadata={ metadata={
@ -92,15 +161,11 @@ class RewardConfig(TrainingArguments):
"https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
}, },
) )
remove_unused_columns: bool = field( activation_offloading: bool = field(
default=False, default=False,
metadata={ metadata={"help": "Whether to offload the activations to the CPU."},
"help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only "
"if the dataset is pretokenized."
},
) )
def __post_init__(self): def __post_init__(self):
self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16 self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16
super().__post_init__() super().__post_init__()
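# Illustrative sketch (not part of this file) of how the new RewardConfig options defined above fit together;
# the option names come from this file, the values are arbitrary examples.
from trl import RewardConfig
training_args = RewardConfig(
    output_dir="my-reward-model",
    max_length=1024,                          # drop pairs whose chosen or rejected sequence exceeds this length
    pad_to_multiple_of=8,                     # pad batches to a multiple of 8 for better tensor-core utilization
    center_rewards_coefficient=0.01,          # auxiliary mean-zero reward penalty (recommended value)
    chat_template_path="Qwen/Qwen3-4B",       # copy the chat template (and any new tokens) from this tokenizer
    model_init_kwargs={"dtype": "bfloat16"},  # forwarded to from_pretrained when `model` is passed as a string
)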

View File

@ -12,132 +12,348 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import contextlib
import logging
import os
import re
from collections import defaultdict from collections import defaultdict
from dataclasses import FrozenInstanceError, replace from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import pandas as pd
import torch import torch
import torch.nn as nn import torch.nn as nn
from accelerate import PartialState, logging import transformers
from accelerate.utils import gather_object from accelerate import PartialState
from datasets import Dataset from accelerate.logging import get_logger
from datasets import Dataset, IterableDataset
from transformers import ( from transformers import (
BaseImageProcessor, AutoModelForSequenceClassification,
AutoTokenizer,
DataCollator, DataCollator,
FeatureExtractionMixin,
PreTrainedModel, PreTrainedModel,
PreTrainedTokenizerBase, PreTrainedTokenizerBase,
ProcessorMixin,
) )
from transformers.data.data_collator import DataCollatorMixin
from transformers.trainer_callback import TrainerCallback from transformers.trainer_callback import TrainerCallback
from transformers.trainer_pt_utils import nested_detach
from transformers.trainer_utils import EvalPrediction from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available, is_rich_available from transformers.utils import is_peft_available
from ..data_utils import maybe_apply_chat_template from ..data_utils import is_conversational
from ..models import prepare_peft_model from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
from .base_trainer import BaseTrainer from .base_trainer import BaseTrainer
from .reward_config import RewardConfig from .reward_config import RewardConfig
from .utils import ( from .utils import disable_dropout_in_model, pad, remove_none_values
RewardDataCollatorWithPadding,
compute_accuracy,
decode_and_strip_padding,
disable_dropout_in_model,
log_table_to_comet_experiment,
print_rich_table,
)
if is_peft_available(): if is_peft_available():
from peft import PeftModel from peft import PeftConfig, PeftModel
logger = logging.get_logger(__name__) logger = get_logger(__name__)
def _tokenize(batch: dict[str, list[Any]], tokenizer: "PreTrainedTokenizerBase") -> dict[str, list[Any]]: # AutoModelForSequenceClassification adds a new classification head when loading a CausalLM. That head is randomly
"""Tokenize a batch from a reward modelling dataset.""" # initialized and triggers a harmless warning about uninitialized weights. We suppress just that specific warning to
new_examples = { # avoid confusing users.
"input_ids_chosen": [], @contextmanager
"attention_mask_chosen": [], def suppress_from_pretrained_warning(logger: logging.Logger):
"input_ids_rejected": [], pattern = re.compile(
"attention_mask_rejected": [], r"^Some weights of \S+ were not initialized from the model checkpoint at \S+ and are newly initialized: "
} r"\[.*\]\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and "
for chosen, rejected in zip(batch["chosen"], batch["rejected"]): r"inference\.$"
tokenized_chosen = tokenizer(chosen) )
tokenized_rejected = tokenizer(rejected)
new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])
return new_examples class _Filter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
return not pattern.search(record.getMessage())
f = _Filter()
logger.addFilter(f)
try:
yield
finally:
logger.removeFilter(f)
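# Illustrative usage sketch (mirrors the call made later in `RewardTrainer.__init__`): wrapping `from_pretrained`
# in the context manager above hides only the expected "newly initialized" warning for the fresh score head, while
# any other warning emitted by the same logger still gets through. The model id is just an example.
with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
    example_model = AutoModelForSequenceClassification.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct", num_labels=1)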
@dataclass
class DataCollatorForPreference(DataCollatorMixin):
"""
Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch.
This collator expects each example in the input list to be a dictionary containing the `"chosen_input_ids"` and
`"rejected_input_ids"` keys. The collator returns a dictionary containing the following keys:
- `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch. The first half of the batch
corresponds to the `"chosen_input_ids"` and the second half to the `"rejected_input_ids"`.
- `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
Optionally, the examples can contain a `"margin"` key, in which case the returned dictionary will also contain a
`"margin"` key with a tensor of margins.
Args:
pad_token_id (`int`):
Token ID to use for padding.
pad_to_multiple_of (`int`, *optional*):
If set, the sequences will be padded to a multiple of this value.
return_tensors (`str`, *optional*, defaults to `"pt"`):
Type of Tensor to return. Only `"pt"` is currently supported.
Examples:
```python
>>> from trl.trainer.reward_trainer import DataCollatorForPreference
>>> collator = DataCollatorForPreference(pad_token_id=0)
>>> examples = [
... {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5]},
... {"chosen_input_ids": [6, 7], "rejected_input_ids": [8]},
... ]
>>> collator(examples)
{'input_ids': tensor([[1, 2, 3],
[6, 7, 0],
[4, 5, 0],
[8, 0, 0]]),
'attention_mask': tensor([[1, 1, 1],
[1, 1, 0],
[1, 1, 0],
[1, 0, 0]])}
>>> examples = [
... {"chosen_input_ids": [1, 2, 3], "rejected_input_ids": [4, 5], "margin": 0.5},
... {"chosen_input_ids": [6, 7], "rejected_input_ids": [8], "margin": 0.0},
... ]
>>> collator(examples)
{'input_ids': tensor([[1, 2, 3],
[6, 7, 0],
[4, 5, 0],
[8, 0, 0]]),
'attention_mask': tensor([[1, 1, 1],
[1, 1, 0],
[1, 1, 0],
[1, 0, 0]]),
'margin': tensor([0.5, 0.0])}
```
"""
pad_token_id: int
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
# Convert to tensor
chosen_input_ids = [torch.tensor(example["chosen_input_ids"]) for example in examples]
rejected_input_ids = [torch.tensor(example["rejected_input_ids"]) for example in examples]
if "margin" in examples[0]:
margins = torch.tensor([example["margin"] for example in examples], dtype=torch.float)
input_ids = chosen_input_ids + rejected_input_ids
attention_mask = [torch.ones_like(ids) for ids in input_ids]
output = {}
# Pad
output["input_ids"] = pad(
input_ids,
padding_value=self.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
output["attention_mask"] = pad(
attention_mask,
padding_value=0,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
if "margin" in examples[0]:
output["margin"] = margins
return output
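# Illustrative sketch (not part of this file) of how the batch layout produced above is consumed in `compute_loss`
# below: because all chosen sequences come first and all rejected sequences second, a single forward pass yields
# logits that can be split back into the two halves with `torch.chunk`.
import torch
logits = torch.tensor([[2.0], [1.0], [0.5], [-1.0]])  # 2 pairs -> 4 rows: chosen of pair 1 and 2, then rejected
rewards_chosen, rewards_rejected = torch.chunk(logits.squeeze(-1), chunks=2)
assert rewards_chosen.tolist() == [2.0, 1.0] and rewards_rejected.tolist() == [0.5, -1.0]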
class RewardTrainer(BaseTrainer): class RewardTrainer(BaseTrainer):
""" """
Trainer for custom reward. Trainer for Outcome-supervised Reward Models (ORM).
This class is a wrapper around the [`~transformers.Trainer`] class and inherits all of its attributes and methods.
Example:
```python
from trl import RewardTrainer
from datasets import load_dataset
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
trainer = RewardTrainer(model="Qwen/Qwen2.5-0.5B-Instruct", train_dataset=dataset)
trainer.train()
```
Args: Args:
model ([`~transformers.PreTrainedModel`] or `torch.nn.Module`, *optional*): model (`Union[str, PreTrainedModel]`):
Model to be trained, preferably an [`~transformers.AutoModelForSequenceClassification`]. Model to be trained. Can be either:
- A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or a
path to a *directory* containing model weights saved using
[`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
using `AutoModelForSequenceClassification.from_pretrained` with the keyword arguments in
`args.model_init_kwargs`.
- A sequence classification [`~transformers.PreTrainedModel`] object.
args ([`RewardConfig`], *optional*): args ([`RewardConfig`], *optional*):
Training arguments. Configuration for this trainer. If `None`, a default configuration is used.
data_collator ([`~transformers.DataCollator`], *optional*): data_collator ([`~transformers.DataCollator`], *optional*):
The data collator to use for training. If None is specified, the default data collator Function to use to form a batch from a list of elements of the processed `train_dataset` or `eval_dataset`.
[`~trainer.utils.RewardDataCollatorWithPadding`] will be used which will pad the sequences to the maximum Will default to [`~trainer.reward_trainer.DataCollatorForPreference`].
length of the sequences in the batch, given a dataset of paired sequences. train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
train_dataset ([`~datasets.Dataset`], *optional*): Dataset to use for training. This trainer supports [preference](#preference) type (both implicit and
The dataset to use for training. explicit prompt). The format of the samples can be either:
eval_dataset ([`~datasets.Dataset`], *optional*):
The dataset to use for evaluation. - [Standard](dataset_formats#standard): Each sample contains plain text.
processing_class ([`~transformers.PreTrainedTokenizerBase`], [`~transformers.BaseImageProcessor`], [`~transformers.FeatureExtractionMixin`] or [`~transformers.ProcessorMixin`], *optional*): - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
Processing class used to process the data. If provided, will be used to automatically process the inputs and content).
for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
reuse the fine-tuned model. The trainer also supports processed datasets (tokenized) as long as they contain an `chosen_input_ids` and
model_init (`Callable[[], transformers.PreTrainedModel]`, *optional*): `rejected_input_ids` fields.
The model initializer to use for training. If None is specified, the default model initializer will be eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
used. Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
compute_metrics (`Callable[[transformers.EvalPrediction], dict]`, *optional*, defaults to [`~trainer.utils.compute_accuracy`]): processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*):
Function to compute metrics at evaluation. Must take in an [`~transformers.EvalPrediction`] and return a Tokenizer used to process the data. If `None`, the tokenizer is loaded from the model's name with
dictionary string to float. [`~transformers.AutoTokenizer.from_pretrained`]. A padding token, `processing_class.pad_token`, must be
callbacks (`list` of [`~transformers.TrainerCallback`], *optional*): set. If the processing class has not set a padding token, `processing_class.eos_token` will be used as the
Callbacks to use during training. default.
optimizers (`tuple` of `torch.optim.Optimizer` and `torch.optim.lr_scheduler.LambdaLR`, *optional*, defaults to `(None, None)`): compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
Tuple containing the optimizer and the learning rate scheduler to use for training. The function that will be used to compute metrics at evaluation. Must take a
[`~transformers.EvalPrediction`] and return a dictionary string to metric values. When passing
[`RewardConfig`] with `batch_eval_metrics` set to `True`, your `compute_metrics` function must take a
boolean `compute_result` argument. This will be triggered after the last eval batch to signal that the
function needs to calculate and return the global summary statistics rather than accumulating the
batch-level statistics.
callbacks (list of [`~transformers.TrainerCallback`], *optional*):
List of callbacks to customize the training loop. Will add those to the list of default callbacks detailed
in [here](https://huggingface.co/docs/transformers/main_classes/callback).
If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
method.
optimizers (`tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]]`, *optional*, defaults to `(None, None)`):
A tuple containing the optimizer and the scheduler to use. Will default to an instance of `AdamW` on your
model and a scheduler given by [`~transformers.get_linear_schedule_with_warmup`] controlled by `args`.
optimizer_cls_and_kwargs (`tuple[Type[torch.optim.Optimizer], Dict[str, Any]]`, *optional*):
A tuple containing the optimizer class and keyword arguments to use. Overrides `optim` and `optim_args` in
`args`. Incompatible with the `optimizers` argument.
Unlike `optimizers`, this argument avoids the need to place model parameters on the correct devices before
initializing the Trainer.
preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*): preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`, *optional*):
Function to preprocess the logits before computing the metrics. Must take in the `logits` and `labels` and A function that preprocess the logits right before caching them at each evaluation step. Must take two
return the logits to be used for metrics computation. tensors, the logits and the labels, and return the logits once processed as desired. The modifications made
peft_config (`dict`, *optional*): by this function will be reflected in the predictions received by `compute_metrics`.
PEFT configuration to use PEFT for training. If `None`, PEFT is not used. If provided, the `model` will be
wrapped with the specified PEFT adapter. Note that the labels (second parameter) will be `None` if the dataset does not have them.
peft_config ([`~peft.PeftConfig`], *optional*):
PEFT configuration used to wrap the model. If `None`, the model is not wrapped. Note that if the loaded
model is a causal LM, it's highly recommended to set `modules_to_save=["score"]` in the PEFT configuration
to ensure that the reward head is properly trained.
""" """
_tag_names = ["trl", "reward-trainer"] _tag_names = ["trl", "reward-trainer"]
_name = "Reward" _name = "Reward"
_template_file = "rm_model_card.md"
def __init__( def __init__(
self, self,
model: Optional[Union[PreTrainedModel, nn.Module]] = None, model: Union[str, PreTrainedModel],
args: Optional[RewardConfig] = None, args: Optional[RewardConfig] = None,
data_collator: Optional[DataCollator] = None, data_collator: Optional[DataCollator] = None,
train_dataset: Optional[Dataset] = None, train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None, eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
processing_class: Optional[ processing_class: Optional[PreTrainedTokenizerBase] = None,
Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
] = None,
model_init: Optional[Callable[[], PreTrainedModel]] = None,
compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None, compute_metrics: Optional[Callable[[EvalPrediction], dict]] = None,
callbacks: Optional[list[TrainerCallback]] = None, callbacks: Optional[list[TrainerCallback]] = None,
optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = ( optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
None, optimizer_cls_and_kwargs: Optional[tuple[type[torch.optim.Optimizer], dict[str, Any]]] = None,
None,
),
preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
peft_config: Optional[dict] = None, peft_config: Optional["PeftConfig"] = None,
): ):
# Args
if args is None:
model_name = model if isinstance(model, str) else model.config._name_or_path
model_name = model_name.split("/")[-1]
args = RewardConfig(f"{model_name}-Reward")
# Model
model_init_kwargs = args.model_init_kwargs or {}
if isinstance(model, str):
model_id = model
dtype = model_init_kwargs.get("dtype")
if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
pass # dtype is already a torch.dtype or "auto" or None
elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
model_init_kwargs["dtype"] = getattr(torch, dtype)
else:
raise ValueError(
"Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing "
f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
)
with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
else:
model_id = model.config._name_or_path
if args.model_init_kwargs is not None:
logger.warning(
"You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
"The `model_init_kwargs` will be ignored."
)
# Processing class
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model_id)
# Handle pad token for processors or tokenizers
if args.eos_token is not None:
eos_token = args.eos_token
eos_token_id = processing_class.convert_tokens_to_ids(eos_token)
if eos_token_id is None:
raise ValueError(
f"The specified `eos_token` ('{eos_token}') is not found in the vocabulary of the given "
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `eos_token` exists "
"in the vocabulary before using it as an EOS token."
)
processing_class.eos_token_id = eos_token_id
if args.chat_template_path is not None:
if os.path.isfile(args.chat_template_path) and args.chat_template_path.endswith((".jinja", ".j2")):
with open(args.chat_template_path, encoding="utf-8") as chat_template_file:
processing_class.chat_template = chat_template_file.read()
added_tokens = []
else:
model, processing_class, added_tokens = clone_chat_template(
model, processing_class, args.chat_template_path
)
else:
added_tokens = []
# PEFT configuration and model wrapping
if peft_config is not None:
if added_tokens:
# Ensure that the added tokens are trainable
if peft_config.trainable_token_indices is None:
peft_config.trainable_token_indices = {"embed_tokens": added_tokens}
elif "embed_tokens" not in peft_config.trainable_token_indices:
peft_config.trainable_token_indices["embed_tokens"] = added_tokens
else:
peft_config.trainable_token_indices["embed_tokens"].extend(added_tokens)
# Ensure that the lm_head is trainable
if peft_config.modules_to_save is None or "lm_head" not in peft_config.modules_to_save:
logger.warning(
"Cloning chat template added new tokens to the tokenizer, but 'lm_head' is not in PEFT's "
"`modules_to_save`. As a result, the model may not learn to generate outputs with these new "
"tokens, leading to degraded generation quality. To fix this, add "
"`modules_to_save=['lm_head']` to your PEFT configuration."
)
if peft_config.modules_to_save is None:
peft_config.modules_to_save = ["lm_head"]
else:
peft_config.modules_to_save.append("lm_head")
if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)): if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
model = prepare_peft_model(model, peft_config, args) model = prepare_peft_model(model, peft_config, args)
@ -145,78 +361,47 @@ class RewardTrainer(BaseTrainer):
if args.disable_dropout: if args.disable_dropout:
disable_dropout_in_model(model) disable_dropout_in_model(model)
if compute_metrics is None: # Pad token (needed for SequenceClassification models)
compute_metrics = compute_accuracy # If not provided, use the one from the processing class or the eos token if the processing class does not have
# a pad token.
pad_token = args.pad_token or processing_class.pad_token or processing_class.eos_token
pad_token_id = processing_class.convert_tokens_to_ids(pad_token)
if pad_token_id is None:
raise ValueError(
f"The specified `pad_token` ('{pad_token}') is not found in the vocabulary of the given "
f"`processing_class` ({processing_class.__class__.__name__}). Ensure that the `pad_token` exists "
"in the vocabulary before using it as a padding token."
)
model.config.pad_token_id = pad_token_id
processing_class.pad_token_id = pad_token_id
# Data collator
if data_collator is None: if data_collator is None:
if processing_class is None: data_collator = DataCollatorForPreference(
raise ValueError( pad_token_id=pad_token_id,
"A processing_class must be specified when using the default RewardDataCollatorWithPadding" pad_to_multiple_of=args.pad_to_multiple_of,
) )
max_length = args.max_length # Dataset
train_dataset = self._prepare_dataset(train_dataset, processing_class, args, "train")
if eval_dataset is not None:
if isinstance(eval_dataset, dict):
eval_dataset = {
key: self._prepare_dataset(dataset, processing_class, args, key)
for key, dataset in eval_dataset.items()
}
else:
eval_dataset = self._prepare_dataset(eval_dataset, processing_class, args, "eval")
data_collator = RewardDataCollatorWithPadding(processing_class) # Initialize the metrics
self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
self._total_train_tokens = 0
if args.remove_unused_columns: # Initialize the Trainer. Parent class will handle:
try: # for bc before https://github.com/huggingface/transformers/pull/25435 # - DeepSpeed configuration (through create_accelerator_and_postprocess)
args.remove_unused_columns = False # - FSDP setup
except FrozenInstanceError: # - Distributed training setup
args = replace(args, remove_unused_columns=False) # - Optimizer and scheduler creation
# warn users
logger.warning(
"When using RewardDataCollatorWithPadding, you should set `remove_unused_columns=False` in your RewardConfig"
" we have set it for you, but you should do it yourself in the future.",
)
self.use_reward_data_collator = True
else:
self.use_reward_data_collator = False
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in Reward, the sampled data does not include the
# "input_ids" key. Instead, the available keys are "input_ids_chosen" and "input_ids_rejected". As a result,
# the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
# operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
# "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
# issued.
model.warnings_issued["estimate_tokens"] = True
if "input_ids_chosen" not in train_dataset.column_names:
with PartialState().main_process_first():
fn_kwargs = {"tokenizer": processing_class}
train_dataset = train_dataset.map(maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class})
train_dataset = train_dataset.map(
_tokenize,
batched=True,
fn_kwargs=fn_kwargs,
num_proc=args.dataset_num_proc,
)
# This filter is important because otherwise you get samples that exceed the model's context length and
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
# user might get surprised if N samples are missing from training.
train_dataset = train_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
num_proc=args.dataset_num_proc,
)
if eval_dataset is not None:
eval_dataset = eval_dataset.map(
maybe_apply_chat_template, fn_kwargs={"tokenizer": processing_class}
)
eval_dataset = eval_dataset.map(
_tokenize,
fn_kwargs=fn_kwargs,
batched=True,
num_proc=args.dataset_num_proc,
)
# This filter is important because otherwise you get samples that exceed the model's context length and
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
# user might get surprised if N samples are missing from training.
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= max_length
and len(x["input_ids_rejected"]) <= max_length,
num_proc=args.dataset_num_proc,
)
super().__init__( super().__init__(
model=model, model=model,
@ -225,35 +410,140 @@ class RewardTrainer(BaseTrainer):
train_dataset=train_dataset, train_dataset=train_dataset,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,
processing_class=processing_class, processing_class=processing_class,
model_init=model_init,
compute_metrics=compute_metrics, compute_metrics=compute_metrics,
callbacks=callbacks, callbacks=callbacks,
optimizers=optimizers, optimizers=optimizers,
optimizer_cls_and_kwargs=optimizer_cls_and_kwargs,
preprocess_logits_for_metrics=preprocess_logits_for_metrics, preprocess_logits_for_metrics=preprocess_logits_for_metrics,
) )
# During evaluation, Trainer calls compute_loss() only if can_return_loss is True and label_names is empty.
self.can_return_loss = True
self.label_names = []
# Initialize activation offloading context
if self.args.activation_offloading:
self.maybe_activation_offload_context = get_act_offloading_ctx_manager(model=self.model)
else:
self.maybe_activation_offload_context = contextlib.nullcontext()
# Add tags for models that have been loaded with the correct transformers version # Add tags for models that have been loaded with the correct transformers version
if hasattr(self.model, "add_model_tags"): if hasattr(self.model, "add_model_tags"):
self.model.add_model_tags(self._tag_names) self.model.add_model_tags(self._tag_names)
self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
def _prepare_dataset(
self,
dataset: Union[Dataset, IterableDataset],
processing_class: PreTrainedTokenizerBase,
args: RewardConfig,
dataset_name: str,
) -> Union[Dataset, IterableDataset]:
# Tabular backends like Arrow/Parquet insert `None` for mismatched keys in nested structures. Clean them from
# sampled data.
if isinstance(dataset, Dataset): # IterableDataset does not support `with_transform`
dataset = dataset.with_transform(remove_none_values)
# If the dataset is already preprocessed (tokenized), skip the processing steps.
column_names = list(next(iter(dataset)).keys())
is_processed = "chosen_input_ids" in column_names and "rejected_input_ids" in column_names
# Build the kwargs for the `map` function
map_kwargs = {}
if isinstance(dataset, Dataset): # IterableDataset does not support num_proc
map_kwargs["num_proc"] = args.dataset_num_proc
with PartialState().main_process_first():
if not is_processed:
# Add EOS token to the end of the sequences if needed
first_example = next(iter(dataset))
if not is_conversational(first_example):
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Adding EOS to {dataset_name} dataset"
def add_eos(example, eos_token):
if not example["chosen"].endswith(eos_token):
example["chosen"] = example["chosen"] + eos_token
if "rejected" in example and not example["rejected"].endswith(eos_token):
example["rejected"] = example["rejected"] + eos_token
return example
dataset = dataset.map(
add_eos,
fn_kwargs={"eos_token": processing_class.eos_token},
**map_kwargs,
)
# Tokenize the dataset
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Tokenizing {dataset_name} dataset"
def tokenize_fn(example, processing_class):
if "prompt" in example: # explicit prompt case
example["chosen"] = example["prompt"] + example["chosen"]
example["rejected"] = example["prompt"] + example["rejected"]
if is_conversational(example):
chosen_input_ids = processing_class.apply_chat_template(
example["chosen"],
tools=example.get("tools"),
**example.get("chat_template_kwargs", {}),
)
rejected_input_ids = processing_class.apply_chat_template(
example["rejected"],
tools=example.get("tools"),
**example.get("chat_template_kwargs", {}),
)
output = {"chosen_input_ids": chosen_input_ids, "rejected_input_ids": rejected_input_ids}
else:
output = {
"chosen_input_ids": processing_class(text=example["chosen"])["input_ids"],
"rejected_input_ids": processing_class(text=example["rejected"])["input_ids"],
}
return output
dataset = dataset.map(tokenize_fn, fn_kwargs={"processing_class": processing_class}, **map_kwargs)
# Filter samples that are longer than `max_length`
if args.max_length is not None:
if isinstance(dataset, Dataset): # `IterableDataset.map` does not support `desc`
map_kwargs["desc"] = f"Filtering {dataset_name} >{args.max_length} tokens"
dataset = dataset.filter(
lambda example: len(example["chosen_input_ids"]) <= args.max_length
and len(example["rejected_input_ids"]) <= args.max_length,
**map_kwargs,
)
return dataset
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs (usually, "input_ids"
# and "attention_mask").
if self._signature_columns is None:
self._signature_columns = ["chosen_input_ids", "rejected_input_ids", "margin"]
     def compute_loss(
         self,
-        model: Union[PreTrainedModel, nn.Module],
+        model: nn.Module,
         inputs: dict[str, Union[torch.Tensor, Any]],
-        return_outputs=False,
-        num_items_in_batch=None,
-    ) -> Union[torch.Tensor, tuple[torch.Tensor, dict[str, torch.Tensor]]]:
-        rewards_chosen = model(
-            input_ids=inputs["input_ids_chosen"],
-            attention_mask=inputs["attention_mask_chosen"],
-            return_dict=True,
-        )["logits"]
-        rewards_rejected = model(
-            input_ids=inputs["input_ids_rejected"],
-            attention_mask=inputs["attention_mask_rejected"],
-            return_dict=True,
-        )["logits"]
-        # calculate loss, optionally modulate with margin
+        return_outputs: bool = False,
+        num_items_in_batch: Optional[torch.Tensor] = None,
+    ):
+        """
+        Compute training loss and additionally compute token accuracies
+        """
+        mode = "train" if self.model.training else "eval"
+
+        # If not set, defaults from model config and may warn since cache isn't compatible with gradient checkpointing
+        inputs["use_cache"] = False
+        outputs = model(**inputs)
+
+        # Split the rewards into chosen and rejected
+        rewards_chosen, rewards_rejected = torch.chunk(outputs.logits.squeeze(-1), chunks=2)
+
+        # Calculate loss, optionally modulate with margin
         if "margin" in inputs:
             loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
         else:
@@ -262,86 +552,45 @@ class RewardTrainer(BaseTrainer):
         if self.args.center_rewards_coefficient is not None:
             loss += self.args.center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)
 
-        if return_outputs:
-            return loss, {
-                "rewards_chosen": rewards_chosen,
-                "rewards_rejected": rewards_rejected,
-            }
-        return loss
-
-    def prediction_step(
-        self,
-        model: Union[PreTrainedModel, nn.Module],
-        inputs: dict[str, Union[torch.Tensor, Any]],
-        prediction_loss_only: bool,
-        ignore_keys: Optional[list[str]] = None,
-    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
-        inputs = self._prepare_inputs(inputs)
-        if ignore_keys is None:
-            if hasattr(self.model, "config"):
-                ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
-            else:
-                ignore_keys = []
-
-        with torch.no_grad():
-            loss, logits_dict = self.compute_loss(model, inputs, return_outputs=True)
-
-        if prediction_loss_only:
-            return (loss, None, None)
-
-        loss = loss.detach()
-        logits = tuple(v for k, v in logits_dict.items() if k not in ignore_keys)
-        logits = nested_detach(logits)
-        # Stack accepted against rejected, mean over logits
-        # and softmax to get preferences between accepted and rejected to sum to 1
-        logits = torch.stack(logits).mean(dim=2).softmax(dim=0).T
-
-        labels = torch.zeros(logits.shape[0])
-        labels = self._prepare_inputs(labels)
-
-        return loss, logits, labels
-
-    def evaluate(self, *args, **kwargs):
-        num_print_samples = kwargs.pop("num_print_samples", 4)
-        self.visualize_samples(num_print_samples)
-        return super().evaluate(*args, **kwargs)
-
-    def visualize_samples(self, num_print_samples: int):
-        """
-        Visualize the reward model logits prediction
-
-        Args:
-            num_print_samples (`int`, defaults to `4`):
-                The number of samples to print. Set to `-1` to print all samples.
-        """
-        eval_dataloader = self.get_eval_dataloader()
-        table = defaultdict(list)
-        for _, inputs in enumerate(eval_dataloader):
-            _, logits, _ = self.prediction_step(self.model, inputs, prediction_loss_only=False)
-            chosen_text = decode_and_strip_padding(inputs["input_ids_chosen"], self.processing_class)
-            rejected_text = decode_and_strip_padding(inputs["input_ids_rejected"], self.processing_class)
-            table["chosen_text"].extend(gather_object(chosen_text))
-            table["rejected_text"].extend(gather_object(rejected_text))
-            table["logits"].extend(
-                gather_object([[round(inner_item, 4) for inner_item in item] for item in logits.tolist()])
-            )
-            if num_print_samples >= 0 and len(table["chosen_text"]) >= num_print_samples:
-                break
-        df = pd.DataFrame(table)
-        if self.accelerator.process_index == 0:
-            if is_rich_available():
-                print_rich_table(df[:num_print_samples])
-            if "wandb" in self.args.report_to:
-                import wandb
-
-                if wandb.run is not None:
-                    wandb.log({"completions": wandb.Table(dataframe=df)})
-            if "comet_ml" in self.args.report_to:
-                log_table_to_comet_experiment(
-                    name="completions.csv",
-                    table=df,
-                )
+        if mode == "train":
+            num_tokens_in_batch = self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
+            self._total_train_tokens += num_tokens_in_batch
+            self._metrics[mode]["num_tokens"] = [self._total_train_tokens]
+
+        # Compute min, mean, max, accuracy and margin
+        with torch.no_grad():
+            all_rewards = self.accelerator.gather(outputs.logits)
+            self._metrics[mode]["min_reward"].append(all_rewards.min().item())
+            self._metrics[mode]["mean_reward"].append(all_rewards.mean().item())
+            self._metrics[mode]["max_reward"].append(all_rewards.max().item())
+
+            mean_accuracy = (rewards_chosen > rewards_rejected).float().mean()
+            mean_accuracy = self.accelerator.gather_for_metrics(mean_accuracy).mean().item()
+            self._metrics[mode]["accuracy"].append(mean_accuracy)
+
+            mean_margin = (rewards_chosen - rewards_rejected).mean()
+            mean_margin = self.accelerator.gather_for_metrics(mean_margin).mean()
+            self._metrics[mode]["margin"].append(mean_margin.item())
+
+        return (loss, outputs) if return_outputs else loss
+
+    # Override training step to add activation offloading context.
+    def training_step(self, *args, **kwargs):
+        with self.maybe_activation_offload_context:
+            return super().training_step(*args, **kwargs)
+
+    def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
+        mode = "train" if self.model.training else "eval"
+        metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()}  # average the metrics
+
+        # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
+        # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
+        if mode == "eval":
+            metrics = {f"eval_{key}": val for key, val in metrics.items()}
+
+        logs.update(metrics)
+        super().log(logs, start_time)
+        self._metrics[mode].clear()
 
     # Ensure the model card is saved along with the checkpoint
     def _save_checkpoint(self, model, trial):
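The refactored loss path is compact: the collator stacks the chosen examples followed by the rejected examples in a single forward pass, so chunking the pooled logits in half recovers the two reward vectors. A minimal standalone rendering of the same pairwise objective (the toy tensors and coefficient values are made up, and the real trainer additionally gathers metrics across processes):

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.3, 0.2, -0.1, 0.4])  # toy pooled rewards: 2 chosen then 2 rejected
margin = torch.tensor([0.0, 0.5])             # optional per-pair margin column from the dataset
center_rewards_coefficient = 0.01             # optional regularizer keeping rewards near zero

# Split the batch back into chosen/rejected rewards and apply the Bradley-Terry style loss
rewards_chosen, rewards_rejected = torch.chunk(logits, chunks=2)
loss = -F.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()
loss = loss + center_rewards_coefficient * torch.mean((rewards_chosen + rewards_rejected) ** 2)

accuracy = (rewards_chosen > rewards_rejected).float().mean()  # logged as the "accuracy" metric
print(loss.item(), accuracy.item())
```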

View File

@@ -15,10 +15,9 @@
 import contextlib
 import os
 from collections import defaultdict
-from collections.abc import Mapping
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -56,6 +55,7 @@ from .utils import (
     entropy_from_logits,
     flush_left,
     pad,
+    remove_none_values,
     selective_log_softmax,
 )
@@ -66,7 +66,6 @@ if is_peft_available():
 
 logger = logging.get_logger(__name__)
 
-TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
-
 FLASH_ATTENTION_VARIANTS = {
     "flash_attention_2",
@@ -77,38 +76,6 @@ FLASH_ATTENTION_VARIANTS = {
 }
 
-def remove_none_values(example: TListOrMapping) -> TListOrMapping:
-    """
-    Recursively removes entries with `None` values from a nested structure (list or dictionary).
-
-    Args:
-        example (`list` or `Mapping`):
-            Input nested structure (list or dictionary) from which to remove `None`.
-
-    Example:
-    ```python
-    >>> [
-    ...     {
-    ...         "a": {"aa": None, "ab": 1},
-    ...         "b": "my_string",
-    ...     }
-    ... ]
-    >>> remove_none_values(example)
-    [{'a': {'ab': 1}, 'b': 'my_string'}]
-    ```
-    """
-    if isinstance(example, list):
-        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
-    elif isinstance(example, Mapping):
-        return {
-            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
-            for key, value in example.items()
-            if value is not None
-        }
-    else:
-        raise TypeError("Input must be a list or a dictionary.")
-
 def get_dataset_column_names(dataset: Union[Dataset, IterableDataset]) -> list[str]:
     return list(next(iter(dataset)).keys()) if dataset.column_names is None else dataset.column_names
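With this move, the helper is shared through `trl.trainer.utils` and the trainer simply imports it from there. A quick usage sketch, assuming a TRL version that includes this change (the conversational example is made up):

```python
from trl.trainer.utils import remove_none_values

# `None` entries (e.g. unused tool_calls fields) are stripped recursively before templating
messages = [{"role": "assistant", "content": "Hi!", "tool_calls": None}]
print(remove_none_values(messages))  # [{'role': 'assistant', 'content': 'Hi!'}]
```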

View File

@@ -18,11 +18,12 @@ import json
 import os
 import random
 import socket
-from collections.abc import Sequence, Sized
+import warnings
+from collections.abc import Mapping, Sequence, Sized
 from dataclasses import dataclass, field
 from importlib.metadata import version
 from itertools import accumulate
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Optional, TypeVar, Union
 
 import numpy as np
 import pandas as pd
@@ -219,6 +220,10 @@ class RewardDataCollatorWithPadding:
     r"""
     Reward DataCollator class that pads the inputs to the maximum length of the batch.
 
+    > [!WARNING]
+    > This class is deprecated and will be removed in version 0.27.0. Please use
+    > `trl.trainer.reward_trainer.DataCollatorForPreference` instead.
+
     Args:
         tokenizer (`PreTrainedTokenizerBase`):
             The tokenizer used for encoding the data.
@@ -235,6 +240,14 @@ class RewardDataCollatorWithPadding:
     pad_to_multiple_of: Optional[int] = None
     return_tensors: str = "pt"
 
+    def __init__(self, *args, **kwargs) -> None:
+        warnings.warn(
+            "The `RewardDataCollatorWithPadding` is deprecated and will be removed in version 0.27.0. Please use "
+            "`trl.trainer.reward_trainer.DataCollatorForPreference` instead.",
+            DeprecationWarning,
+        )
+        super().__init__(*args, **kwargs)
+
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         features_chosen = []
         features_rejected = []
@@ -1241,6 +1254,10 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
     """
     Decodes the input tensor and strips the padding tokens.
 
+    > [!WARNING]
+    > This function is deprecated and will be removed in a version 0.25.0. If you want to keep using it, please copy
+    > the code into your codebase and use it from there.
+
     Args:
         inputs (`torch.Tensor`):
             The input tensor to be decoded.
@@ -1251,6 +1268,11 @@ def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenize
         `list[str]`:
             The list of decoded strings with padding tokens stripped.
     """
+    warnings.warn(
+        "The function `decode_and_strip_padding` is deprecated and will be removed in a version 0.25.0. If you want "
+        "to keep using it, please copy the code into your codebase and use it from there.",
+        DeprecationWarning,
+    )
     decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False)
     return [d.replace(tokenizer.pad_token, "") for d in decoded]
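As the deprecation notice suggests, callers who still need this behavior can vendor the two-line body shown above. A standalone copy, with the tokenizer checkpoint chosen only as an example:

```python
import torch
from transformers import AutoTokenizer, PreTrainedTokenizerBase


def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenizerBase) -> list[str]:
    # Decode including special tokens, then drop the pad token from each string
    decoded = tokenizer.batch_decode(inputs, skip_special_tokens=False)
    return [d.replace(tokenizer.pad_token, "") for d in decoded]


tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # any tokenizer with a pad_token works
batch = tokenizer(["Hello", "Hello world"], padding=True, return_tensors="pt")
print(decode_and_strip_padding(batch["input_ids"], tokenizer))
```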
@@ -1264,6 +1286,7 @@ def generate_model_card(
     wandb_url: Optional[str],
     trainer_name: str,
     trainer_citation: Optional[str] = None,
+    template_file: Optional[str] = None,
     paper_title: Optional[str] = None,
     paper_id: Optional[str] = None,
     comet_url: Optional[str] = None,
@@ -1290,6 +1313,8 @@ def generate_model_card(
             Trainer name.
         trainer_citation (`str` or `None`, defaults to `None`):
             Trainer citation as a BibTeX entry.
+        template_file (`str` *optional*):
+            Template file name located in the `trl/templates` directory. Defaults to `lm_model_card.md`.
         paper_title (`str` or `None`, defaults to `None`):
             Paper title.
         paper_id (`str` or `None`, defaults to `None`):
@@ -1307,9 +1332,10 @@ def generate_model_card(
         model_name=model_name,
         tags=["generated_from_trainer", *tags],
     )
+    template_file = template_file or "lm_model_card.md"
     card = ModelCard.from_template(
         card_data,
-        template_path=str(pkg_resources.files("trl").joinpath("templates/lm_model_card.md")),
+        template_path=str(pkg_resources.files("trl").joinpath(f"templates/{template_file}")),
         base_model=base_model,
         model_name=model_name,
         hub_model_id=hub_model_id,
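The new `template_file` argument only switches which bundled template backs the generated model card. A rough sketch of the resolution logic in isolation, assuming `trl` is installed; the alternate template name below is hypothetical:

```python
import importlib.resources as pkg_resources


def resolve_template_path(template_file=None):
    # Fall back to the default card template unless a file name is given
    template_file = template_file or "lm_model_card.md"
    return str(pkg_resources.files("trl").joinpath(f"templates/{template_file}"))


print(resolve_template_path())                    # .../trl/templates/lm_model_card.md
print(resolve_template_path("rm_model_card.md"))  # hypothetical reward-model card template
```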
@@ -1956,6 +1982,41 @@ def truncate_with_protected_tokens(
     return torch.stack(truncated_seq), torch.stack(truncated_mask)
 
+TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
+
+
+def remove_none_values(example: TListOrMapping) -> TListOrMapping:
+    """
+    Recursively removes entries with `None` values from a nested structure (list or dictionary).
+
+    Args:
+        example (`list` or `Mapping`):
+            Input nested structure (list or dictionary) from which to remove `None`.
+
+    Example:
+    ```python
+    >>> [
+    ...     {
+    ...         "a": {"aa": None, "ab": 1},
+    ...         "b": "my_string",
+    ...     }
+    ... ]
+    >>> remove_none_values(example)
+    [{'a': {'ab': 1}, 'b': 'my_string'}]
+    ```
+    """
+    if isinstance(example, list):
+        return [remove_none_values(value) if isinstance(value, (dict, list)) else value for value in example]
+    elif isinstance(example, Mapping):
+        return {
+            key: remove_none_values(value) if isinstance(value, (dict, list)) else value
+            for key, value in example.items()
+            if value is not None
+        }
+    else:
+        raise TypeError("Input must be a list or a dictionary.")
+
+
 def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel:
     """
     Create a model from a given path using the specified initialization arguments.