Compare commits

...

7 Commits

Author SHA1 Message Date
a932e2796d ⬆️ Bump dev version (#4293) 2025-10-15 18:11:52 -06:00
04fd1203af Release: v0.24 (#4292) 2025-10-15 18:10:10 -06:00
19d2f97932 Deprecate BestOfNSampler (#4291)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
Co-authored-by: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com>
2025-10-15 18:06:34 -06:00
31caf64778 Remove unused commands directory (#4258)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
2025-10-15 18:01:50 -06:00
8e2d5516ca Add accuracy reward (#4270)
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-10-15 18:01:07 -06:00
94aac4a101 Remove how_to_train.md: outdated training FAQ (#4267)
Co-authored-by: behroozazarkhalili <ermiaazarkhalili>
2025-10-15 23:49:04 +00:00
26b7c2507e Add support for token_type_ids in DPOTrainer (#4285)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
2025-10-15 17:33:35 -06:00
29 changed files with 246 additions and 529 deletions

View File

@ -102,13 +102,6 @@ jobs:
source .venv/bin/activate
make slow_tests
- name: Run end-to-end examples tests on multi GPU
if: always()
run: |
source .venv/bin/activate
uv pip install deepspeed
make test_examples
- name: Generate Reports
if: always()
run: |

View File

@ -24,7 +24,7 @@ jobs:
steps:
- name: Git checkout
uses: actions/checkout@v4
with: { ref: v0.23-release }
with: { ref: v0.24-release }
- name: Set up Python 3.12
uses: actions/setup-python@v5

View File

@ -31,4 +31,4 @@ keywords:
- pytorch
- transformers
license: Apache-2.0
version: "0.23"
version: "0.24"

View File

@ -1,9 +1,8 @@
.PHONY: test precommit common_tests slow_tests test_examples tests_gpu test_experimental
.PHONY: test precommit common_tests slow_tests tests_gpu test_experimental
check_dirs := examples tests trl
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
COMMAND_FILES_PATH = `pwd`/commands
test:
pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
@ -16,18 +15,5 @@ precommit:
slow_tests:
pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
test_examples:
touch temp_results_sft_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
echo $$?','$${file} >> temp_results_sft_tests.txt; \
done
touch temp_results_dpo_tests.txt
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
done
test_experimental:
pytest -k "experimental"

View File

@ -1 +1 @@
0.24.0.dev0
0.25.0.dev0

View File

@ -1,58 +0,0 @@
#!/bin/bash
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_dpo/"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# This is a hack to get the number of available GPUs
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/trl/scripts/dpo.py \
--model_name_or_path $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0

View File

@ -1,59 +0,0 @@
#!/bin/bash
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
# but defaults to QLoRA + PEFT
OUTPUT_DIR="test_sft/"
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
DATASET_NAME="stanfordnlp/imdb"
MAX_STEPS=5
BATCH_SIZE=2
SEQ_LEN=128
# Handle extra arguments in case one passes accelerate configs.
EXTRA_ACCELERATE_ARGS=""
EXTRA_TRAINING_ARGS="""--use_peft \
--load_in_4bit
"""
# Set your number of GPUs here
NUM_GPUS=2
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
EXTRA_ACCELERATE_ARGS=""
else
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
EXTRA_TRAINING_ARGS="--fp16"
else
echo "Keeping QLoRA + PEFT"
fi
fi
CMD="""
accelerate launch $EXTRA_ACCELERATE_ARGS \
--num_processes $NUM_GPUS \
--mixed_precision 'fp16' \
`pwd`/trl/scripts/sft.py \
--model_name $MODEL_NAME \
--dataset_name $DATASET_NAME \
--output_dir $OUTPUT_DIR \
--max_steps $MAX_STEPS \
--per_device_train_batch_size $BATCH_SIZE \
--max_length $SEQ_LEN \
$EXTRA_TRAINING_ARGS
"""
echo "Starting program..."
{ # try
echo $CMD
eval "$CMD"
} || { # catch
# save log for exception
echo "Operation Failed!"
exit 1
}
exit 0

View File

@ -13,8 +13,6 @@
title: Paper Index
- local: experimental
title: Experimental
- local: how_to_train
title: Training FAQ
title: Conceptual Guides
- sections:
- local: clis

View File

@ -1,5 +1,8 @@
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
> [!WARNING]
> Best-of-N sampling is deprecated and will be removed in TRL 0.25.0.
Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
As to how it fares against the RL based fine-tuning, please look in the `examples` directory for a comparison example

View File

@ -1,63 +0,0 @@
# Training FAQ
## What Metrics Should I Look at?
When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
To address this, we recommend focusing on two key metrics first:
**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kl-example.png">
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
</div>
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
## What Is the Concern with Negative KL Divergence?
If you generate text by purely sampling from the model distribution things work fine in general. But when you use the `generate` method there are a few caveats because it does not always purely sample depending on the settings which can cause KL-divergence to go negative. Essentially when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in several cases:
- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected
- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it.
So how should you generate text for PPO training? Let's have a look!
## How to generate text for training?
In order to avoid the KL issues described above we recommend to use the following settings:
```python
generation_kwargs = {
"min_length": -1, # don't ignore the EOS token (see above)
"top_k": 0.0, # no top-k sampling
"top_p": 1.0, # no nucleus sampling
"do_sample": True, # yes, we want to sample
"pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
"max_new_tokens": 32, # specify how many tokens you want to generate at most
}
```
With these settings we usually don't encounter any issues. You can also experiment with other settings but if you encounter issues with negative KL-divergence try to go back to these and see if they persist.
## How can debug your own use-case?
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
- **Inspect the reward model**: If your reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect so you might need to normalize this (e.g. reward of query+response minus reward of the query).
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!

View File

@ -91,7 +91,6 @@ trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
- [SFT Trainer](sft_trainer) - Complete SFT guide
- [DPO Trainer](dpo_trainer) - Preference alignment
- [GRPO Trainer](grpo_trainer) - Group relative policy optimization
- [Training FAQ](how_to_train) - Common questions
### 🚀 Scale Up
@ -141,4 +140,4 @@ Try adjusting the learning rate:
training_args = SFTConfig(learning_rate=2e-5) # Good starting point
```
For more help, see our [Training FAQ](how_to_train) or open an [issue on GitHub](https://github.com/huggingface/trl/issues).
For more help, open an [issue on GitHub](https://github.com/huggingface/trl/issues).

View File

@ -2,14 +2,14 @@
This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].
## Format rewards
## accuracy_reward
### think_format_reward
[[autodoc]] rewards.accuracy_reward
## think_format_reward
[[autodoc]] rewards.think_format_reward
## Other rewards
### get_soft_overlong_punishment
## get_soft_overlong_punishment
[[autodoc]] rewards.get_soft_overlong_punishment

View File

@ -70,8 +70,6 @@ import os
import torch
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import (
GRPOConfig,
@ -83,7 +81,7 @@ from trl import (
get_peft_config,
get_quantization_config,
)
from trl.rewards import think_format_reward
from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space
@ -149,54 +147,6 @@ if __name__ == "__main__":
train_dataset = dataset["train"]
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
################
# Reward Function for Training
################
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
################
# Training
################

View File

@ -57,8 +57,6 @@ import os
import torch
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import (
GRPOConfig,
@ -70,7 +68,7 @@ from trl import (
get_peft_config,
get_quantization_config,
)
from trl.rewards import think_format_reward
from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space
@ -120,54 +118,6 @@ if __name__ == "__main__":
train_dataset = train_dataset.remove_columns(["messages", "problem"])
eval_dataset = eval_dataset.remove_columns(["messages", "problem"])
################
# Reward Function for Training
################
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
################
# Training
################

View File

@ -57,8 +57,6 @@ import os
import torch
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import (
GRPOConfig,
@ -70,7 +68,7 @@ from trl import (
get_peft_config,
get_quantization_config,
)
from trl.rewards import think_format_reward
from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space
@ -136,54 +134,6 @@ if __name__ == "__main__":
train_dataset = dataset["train"]
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
################
# Reward Function for Training
################
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
################
# Training
################

View File

@ -87,8 +87,6 @@ import os
import torch
import transformers
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from transformers import AutoConfig, AutoProcessor, GenerationConfig
from trl import (
@ -102,7 +100,7 @@ from trl import (
get_peft_config,
get_quantization_config,
)
from trl.rewards import think_format_reward
from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space
@ -192,54 +190,6 @@ if __name__ == "__main__":
train_dataset = dataset["train"]
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
################
# Reward Function for Training (same as GRPO VLM)
################
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
################
# Training
################

View File

@ -33,12 +33,10 @@ import os
import torch
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from peft import LoraConfig
from trl import RLOOConfig, RLOOTrainer
from trl.rewards import think_format_reward
from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space
@ -67,52 +65,6 @@ def main():
train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"])
eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"])
# Reward function for training
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
# Training
training_args = RLOOConfig(
output_dir="Qwen3-0.6B-RLOO",

View File

@ -70,8 +70,6 @@ import os
import torch
from datasets import load_dataset
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
from trl import (
ModelConfig,
@ -83,7 +81,7 @@ from trl import (
get_peft_config,
get_quantization_config,
)
from trl.rewards import think_format_reward
from trl.rewards import accuracy_reward, think_format_reward
# Enable logging in a Hugging Face Space
@ -149,54 +147,6 @@ if __name__ == "__main__":
train_dataset = dataset["train"]
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
################
# Reward Function for Training
################
def accuracy_reward(completions, solution: list[str], **kwargs):
"""Reward function that checks if the completion matches the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
"""
rewards = []
contents = [completion[0]["content"] for completion in completions]
for content, sol in zip(contents, solution):
try:
gold_parsed = parse(sol, extraction_mode="first_match")
except Exception:
gold_parsed = []
if len(gold_parsed) != 0:
# Try parsing predicted answer too
try:
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
reward = float(verify(gold_parsed, answer_parsed))
except Exception as e:
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
reward = None
else:
# fallback to text match
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards
################
# Training
################

View File

@ -89,6 +89,9 @@ vlm = [
"torchvision",
"num2words==0.5.14"
]
math_verify = [
"math-verify>=0.5.2",
]
dev = [
# bco
"scikit-learn",

View File

@ -1422,6 +1422,7 @@ class TestDPOVisionTrainer(TrlTestCase):
# ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",),
("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
]
)
def test_vdpo_trainer(self, model_id):

View File

@ -13,9 +13,9 @@
# limitations under the License.
from trl.rewards import get_soft_overlong_punishment, think_format_reward
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
from .testing_utils import TrlTestCase
from .testing_utils import TrlTestCase, require_math_latex
class TestThinkFormatReward(TrlTestCase):
@ -85,3 +85,60 @@ class TestSoftOverlongPunishmentReward:
completion_ids = [[1] * 90] # 90 is between 80 and 100
rewards = reward_fn(completion_ids)
assert round(abs(rewards[0] - -0.5), 4) == 0
class TestAccuracyReward:
@require_math_latex
def test_accuracy_reward_correct_answer(self):
"""Test accuracy_reward with a correct answer."""
completion = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{63}{400}}"}]]
solution = [r"\frac{63}{400}", "63/400"]
rewards = accuracy_reward(completion, solution)
assert rewards[0] == 1.0
assert rewards[1] == 1.0
@require_math_latex
def test_accuracy_reward_wrong_answer(self):
"""Test accuracy_reward with an incorrect answer."""
completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
solution = [r"\frac{63}{400}"]
rewards = accuracy_reward(completion, solution)
assert rewards[0] == 0.0
@require_math_latex
def test_accuracy_reward_wrong_answer_no_latex(self):
"""Test accuracy_reward with an incorrect answer and gold solution with no latex."""
completion = [[{"content": r"\boxed{3}"}]]
solution = ["6"]
rewards = accuracy_reward(completion, solution)
assert rewards[0] == 0.0
@require_math_latex
def test_accuracy_reward_unparseable_gold(self):
"""Test accuracy_reward with an unparseable gold solution."""
completion = [
[{"content": "Answer is forty two."}],
[{"content": "Some other content."}],
[{"content": r"Answer is \boxed{42}."}],
[{"content": r"Answer is \boxed{\mathbf{42}}."}], # Make response bold
[{"content": r"Answer is \boxed{\textbf{42}}."}], # Different latex command for bold
[{"content": r"Answer is \boxed{42}."}],
[{"content": r"Answer is \boxed{42.3456}."}],
]
solution = [
"Answer is forty two.",
"Answer is forty three.",
"Answer is 42.0", # Decimal point
"Answer is 42 43 okay?", # Extra space
"Answer is 42",
r"Answer is \n\boxed{42}", # Newline in gold solution
"Answer is 42.34560", # Extra trailing zero
]
rewards = accuracy_reward(completion, solution)
assert rewards[0] == 1.0 # Should revert to exact text match
assert rewards[1] == 0.0
assert rewards[2] == 1.0
assert rewards[3] == 1.0
assert rewards[4] == 1.0
assert rewards[5] == 1.0
assert rewards[6] == 1.0 # Should ignore trailing zeros

View File

@ -26,12 +26,19 @@ from transformers.testing_utils import torch_device
from transformers.utils import is_peft_available, is_rich_available, is_vision_available
from trl import BaseBinaryJudge, BasePairwiseJudge
from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available
from trl.import_utils import (
is_joblib_available,
is_llm_blender_available,
is_math_verify_available,
is_mergekit_available,
is_vllm_available,
)
require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify")
require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit")
require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft")
require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich")

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Any, Callable, Optional, Union
import torch
@ -42,8 +43,16 @@ class BestOfNSampler:
generation_config ([`~transformers.GenerationConfig`], *optional*):
Generation config passed to the underlying model's `generate` method. See
[`~transformers.GenerationConfig`] for more details.
<Deprecated version="0.24.0">
`BestOfNSampler` is deprecated and will be removed in version 0.25.
</Deprecated>
"""
warnings.warn("`BestOfNSampler` is deprecated and will be removed in TRL 0.25.", FutureWarning, stacklevel=2)
def __init__(
self,
model: PreTrainedModelWrapper,

View File

@ -31,6 +31,7 @@ _fastapi_available = _is_package_available("fastapi")
_joblib_available = _is_package_available("joblib")
_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
_llm_blender_available = _is_package_available("llm_blender")
_math_verify_available = _is_package_available("math_verify")
_mergekit_available = _is_package_available("mergekit")
_pydantic_available = _is_package_available("pydantic")
_requests_available = _is_package_available("requests")
@ -61,6 +62,10 @@ def is_llm_blender_available() -> bool:
return _llm_blender_available
def is_math_verify_available() -> bool:
return _math_verify_available
def is_mergekit_available() -> bool:
return _mergekit_available

View File

@ -20,12 +20,14 @@ from ..import_utils import _LazyModule
_import_structure = {
"accuracy_rewards": ["accuracy_reward"],
"format_rewards": ["think_format_reward"],
"other_rewards": ["get_soft_overlong_punishment"],
}
if TYPE_CHECKING:
from .accuracy_rewards import accuracy_reward
from .format_rewards import think_format_reward
from .other_rewards import get_soft_overlong_punishment

View File

@ -0,0 +1,93 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from trl.import_utils import is_math_verify_available
if is_math_verify_available():
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
r"""
Reward function that checks if the completion is the same as the ground truth.
- If both gold and prediction are parseable → use math verification.
- If not parseable → compare as normalized text.
Args:
completions (`list[list[dict[str, str]]]`):
List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
containing the key `"content"` with the value being the text of the completion.
solution: (`list[str]`):
List of the raw-text solutions to the questions/problems/prompts.
**kwargs:
Additional keyword arguments. This function does not use them, but they are required in the function
signature to ensure compatibility with trainers like [`GRPOTrainer`].
Example:
```python
>>> from trl.rewards import accuracy_reward
>>> solution = [r"\frac{1}{3}", r"\frac{1}{3}"]
>>> completion = [
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
... ]
>>> accuracy_reward(completion, solution)
[1.0, 0.0]
```
"""
if not is_math_verify_available():
raise ImportError("Please install the `math_verify` package to use accuracy_reward")
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
gold_parsed = parse(
sol,
extraction_mode="first_match",
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Compute binary rewards if verifiable, `None` otherwise to skip this example
try:
reward = float(verify(gold_parsed, answer_parsed))
except Exception:
reward = None
else:
# If the gold solution is not parseable, we assign `None` to skip this example
reward = float(content.strip().lower() == sol.strip().lower())
rewards.append(reward)
return rewards

View File

@ -41,7 +41,7 @@ from trl import (
get_dataset,
get_peft_config,
)
from trl.rewards import get_soft_overlong_punishment, think_format_reward
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
logger = logging.get_logger(__name__)
@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
reward_funcs_registry = {
"accuracy_reward": accuracy_reward,
"think_format_reward": think_format_reward,
"get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
}
@ -68,6 +69,7 @@ class GRPOScriptArguments(ScriptArguments):
reward_funcs (`list[str]`, *optional*):
Reward functions to use. Supported values are:
- `"accuracy_reward"`
- `"think_format_reward"`
- `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`)
- any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`).
@ -83,7 +85,7 @@ class GRPOScriptArguments(ScriptArguments):
reward_funcs: Optional[list[str]] = field(
default=None,
metadata={
"help": "Reward functions to use. Supported values are: `think_format_reward`, "
"help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
"`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or "
"any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
},

View File

@ -41,7 +41,7 @@ from trl import (
get_dataset,
get_peft_config,
)
from trl.rewards import get_soft_overlong_punishment, think_format_reward
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
logger = logging.get_logger(__name__)
@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
reward_funcs_registry = {
"accuracy_reward": accuracy_reward,
"think_format_reward": think_format_reward,
"get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
}
@ -68,6 +69,7 @@ class RLOOScriptArguments(ScriptArguments):
reward_funcs (`list[str]`, *optional*):
Reward functions to use. Supported values are:
- `"accuracy_reward"`
- `"think_format_reward"`
- `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`)
- any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`).
@ -83,7 +85,7 @@ class RLOOScriptArguments(ScriptArguments):
reward_funcs: Optional[list[str]] = field(
default=None,
metadata={
"help": "Reward functions to use. Supported values are: `think_format_reward`, "
"help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
"`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or "
"any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
},

View File

@ -177,6 +177,9 @@ class DataCollatorForPreference(DataCollatorMixin):
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
output["ref_chosen_logps"] = ref_chosen_logps
output["ref_rejected_logps"] = ref_rejected_logps
if "token_type_ids" in examples[0]:
token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
output["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left")
return output
@ -790,6 +793,8 @@ class DPOTrainer(BaseTrainer):
output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
if "image_sizes" in processed_features:
output["image_sizes"] = processed_features["image_sizes"][0]
if "token_type_ids" in processed_features:
output["token_type_ids"] = processed_features["token_type_ids"][0]
return output
@ -804,6 +809,7 @@ class DPOTrainer(BaseTrainer):
"chosen_input_ids",
"rejected_input_ids",
"image_sizes",
"token_type_ids",
"ref_chosen_logps",
"ref_rejected_logps",
]
@ -991,6 +997,8 @@ class DPOTrainer(BaseTrainer):
)
if "image_sizes" in batch:
output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
if "token_type_ids" in batch:
output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
# Concatenate the chosen and rejected completions
max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
@ -1516,6 +1524,9 @@ class DPOTrainer(BaseTrainer):
# Concatenate the prompt and completion inputs
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
if "token_type_ids" in concatenated_batch:
prompt_token_type_ids = concatenated_batch["token_type_ids"]
token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
# Mask the prompt but not the completion for the loss
loss_mask = torch.cat(
(torch.zeros_like(prompt_attention_mask), completion_attention_mask),
@ -1528,7 +1539,12 @@ class DPOTrainer(BaseTrainer):
# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
if "token_type_ids" in concatenated_batch:
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
attention_mask, input_ids, loss_mask, token_type_ids
)
else:
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
attention_mask = attention_mask[:, : self.max_length]
input_ids = input_ids[:, : self.max_length]
loss_mask = loss_mask[:, : self.max_length]
@ -1536,11 +1552,22 @@ class DPOTrainer(BaseTrainer):
# Flush right before truncating left, then flush left
# [[0, 0, x, x, x, x], -> [[0, 0, x, x],
# [0, x, x, x, 0, 0]] [0, x, x, x]]
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
if "token_type_ids" in concatenated_batch:
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
attention_mask, input_ids, loss_mask, token_type_ids
)
token_type_ids = token_type_ids[:, -self.max_length :]
else:
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
input_ids = input_ids[:, -self.max_length :]
attention_mask = attention_mask[:, -self.max_length :]
loss_mask = loss_mask[:, -self.max_length :]
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
if "token_type_ids" in concatenated_batch:
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
attention_mask, input_ids, loss_mask, token_type_ids
)
else:
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
else:
raise ValueError(
f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
@ -1550,7 +1577,15 @@ class DPOTrainer(BaseTrainer):
# Flush left to reduce the memory usage
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
if "token_type_ids" in concatenated_batch:
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
attention_mask, input_ids, loss_mask, token_type_ids
)
else:
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
if "token_type_ids" in concatenated_batch:
model_kwargs["token_type_ids"] = token_type_ids
if self.use_logits_to_keep:
# Compute logits_to_keep based on loss_mask pattern: