mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Compare commits
13 Commits
7e9c6e45d5
...
v0.24-rele
Author | SHA1 | Date | |
---|---|---|---|
04fd1203af | |||
19d2f97932 | |||
31caf64778 | |||
8e2d5516ca | |||
94aac4a101 | |||
26b7c2507e | |||
aa25c2697c | |||
93c7d88563 | |||
c7c041ecc8 | |||
ef40c047aa | |||
7e0adbc552 | |||
773afd9314 | |||
966b397201 |
7
.github/workflows/slow-tests.yml
vendored
7
.github/workflows/slow-tests.yml
vendored
@ -102,13 +102,6 @@ jobs:
|
||||
source .venv/bin/activate
|
||||
make slow_tests
|
||||
|
||||
- name: Run end-to-end examples tests on multi GPU
|
||||
if: always()
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
uv pip install deepspeed
|
||||
make test_examples
|
||||
|
||||
- name: Generate Reports
|
||||
if: always()
|
||||
run: |
|
||||
|
2
.github/workflows/tests_latest.yml
vendored
2
.github/workflows/tests_latest.yml
vendored
@ -24,7 +24,7 @@ jobs:
|
||||
steps:
|
||||
- name: Git checkout
|
||||
uses: actions/checkout@v4
|
||||
with: { ref: v0.23-release }
|
||||
with: { ref: v0.24-release }
|
||||
|
||||
- name: Set up Python 3.12
|
||||
uses: actions/setup-python@v5
|
||||
|
@ -31,4 +31,4 @@ keywords:
|
||||
- pytorch
|
||||
- transformers
|
||||
license: Apache-2.0
|
||||
version: "0.23"
|
||||
version: "0.24"
|
||||
|
16
Makefile
16
Makefile
@ -1,9 +1,8 @@
|
||||
.PHONY: test precommit common_tests slow_tests test_examples tests_gpu test_experimental
|
||||
.PHONY: test precommit common_tests slow_tests tests_gpu test_experimental
|
||||
|
||||
check_dirs := examples tests trl
|
||||
|
||||
ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
|
||||
COMMAND_FILES_PATH = `pwd`/commands
|
||||
|
||||
test:
|
||||
pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/
|
||||
@ -16,18 +15,5 @@ precommit:
|
||||
slow_tests:
|
||||
pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
|
||||
|
||||
test_examples:
|
||||
touch temp_results_sft_tests.txt
|
||||
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
|
||||
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \
|
||||
echo $$?','$${file} >> temp_results_sft_tests.txt; \
|
||||
done
|
||||
|
||||
touch temp_results_dpo_tests.txt
|
||||
for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \
|
||||
TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \
|
||||
echo $$?','$${file} >> temp_results_dpo_tests.txt; \
|
||||
done
|
||||
|
||||
test_experimental:
|
||||
pytest -k "experimental"
|
||||
|
@ -1,58 +0,0 @@
|
||||
#!/bin/bash
|
||||
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_dpo/"
|
||||
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
|
||||
# Handle extra arguments in case one passes accelerate configs.
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
EXTRA_TRAINING_ARGS="""--use_peft \
|
||||
--load_in_4bit
|
||||
"""
|
||||
|
||||
# This is a hack to get the number of available GPUs
|
||||
NUM_GPUS=2
|
||||
|
||||
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
else
|
||||
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
||||
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
||||
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
||||
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
||||
EXTRA_TRAINING_ARGS="--fp16"
|
||||
else
|
||||
echo "Keeping QLoRA + PEFT"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
CMD="""
|
||||
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--num_processes $NUM_GPUS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/trl/scripts/dpo.py \
|
||||
--model_name_or_path $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
echo "Starting program..."
|
||||
|
||||
{ # try
|
||||
echo $CMD
|
||||
eval "$CMD"
|
||||
} || { # catch
|
||||
# save log for exception
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
@ -1,59 +0,0 @@
|
||||
#!/bin/bash
|
||||
# This script runs an SFT example end-to-end on a tiny model using different possible configurations
|
||||
# but defaults to QLoRA + PEFT
|
||||
OUTPUT_DIR="test_sft/"
|
||||
MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
DATASET_NAME="stanfordnlp/imdb"
|
||||
MAX_STEPS=5
|
||||
BATCH_SIZE=2
|
||||
SEQ_LEN=128
|
||||
|
||||
|
||||
# Handle extra arguments in case one passes accelerate configs.
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
EXTRA_TRAINING_ARGS="""--use_peft \
|
||||
--load_in_4bit
|
||||
"""
|
||||
|
||||
# Set your number of GPUs here
|
||||
NUM_GPUS=2
|
||||
|
||||
if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then
|
||||
EXTRA_ACCELERATE_ARGS=""
|
||||
else
|
||||
EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG"
|
||||
# For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed
|
||||
# on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training.
|
||||
if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then
|
||||
EXTRA_TRAINING_ARGS="--fp16"
|
||||
else
|
||||
echo "Keeping QLoRA + PEFT"
|
||||
fi
|
||||
fi
|
||||
|
||||
|
||||
CMD="""
|
||||
accelerate launch $EXTRA_ACCELERATE_ARGS \
|
||||
--num_processes $NUM_GPUS \
|
||||
--mixed_precision 'fp16' \
|
||||
`pwd`/trl/scripts/sft.py \
|
||||
--model_name $MODEL_NAME \
|
||||
--dataset_name $DATASET_NAME \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--max_steps $MAX_STEPS \
|
||||
--per_device_train_batch_size $BATCH_SIZE \
|
||||
--max_length $SEQ_LEN \
|
||||
$EXTRA_TRAINING_ARGS
|
||||
"""
|
||||
|
||||
echo "Starting program..."
|
||||
|
||||
{ # try
|
||||
echo $CMD
|
||||
eval "$CMD"
|
||||
} || { # catch
|
||||
# save log for exception
|
||||
echo "Operation Failed!"
|
||||
exit 1
|
||||
}
|
||||
exit 0
|
@ -13,10 +13,6 @@
|
||||
title: Paper Index
|
||||
- local: experimental
|
||||
title: Experimental
|
||||
- local: how_to_train
|
||||
title: Training FAQ
|
||||
- local: logging
|
||||
title: Understanding Logs
|
||||
title: Conceptual Guides
|
||||
- sections:
|
||||
- local: clis
|
||||
@ -59,8 +55,6 @@
|
||||
title: LoRA Without Regret
|
||||
- local: sentiment_tuning
|
||||
title: Sentiment Tuning
|
||||
- local: using_llama_models
|
||||
title: Training StackLlama
|
||||
- local: multi_adapter_rl
|
||||
title: Multi Adapter RLHF
|
||||
title: Examples
|
||||
|
@ -1,5 +1,8 @@
|
||||
# Best of N sampling: Alternative ways to get better model output without RL based fine-tuning
|
||||
|
||||
> [!WARNING]
|
||||
> Best-of-N sampling is deprecated and will be removed in TRL 0.25.0.
|
||||
|
||||
Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output.
|
||||
As to how it fares against the RL based fine-tuning, please look in the `examples` directory for a comparison example
|
||||
|
||||
|
@ -1,65 +0,0 @@
|
||||
# Training FAQ
|
||||
|
||||
## What Metrics Should I Look at?
|
||||
|
||||
When performing classical supervised fine-tuning of language models, the loss (especially the validation loss) serves as a good indicator of the training progress. However, in Reinforcement Learning (RL), the loss becomes less informative about the model's performance, and its value may fluctuate while the actual performance improves.
|
||||
|
||||
To address this, we recommend focusing on two key metrics first:
|
||||
|
||||
**Mean Reward**: The primary goal is to maximize the reward achieved by the model during RL training.
|
||||
**Objective KL Divergence**: KL divergence (Kullback-Leibler divergence) measures the dissimilarity between two probability distributions. In the context of RL training, we use it to quantify the difference between the current model and a reference model. Ideally, we want to keep the KL divergence between 0 and 10 to ensure the model's generated text remains close to what the reference model produces.
|
||||
|
||||
However, there are more metrics that can be useful for debugging, check out the [logging section](logging).
|
||||
|
||||
## Why Do We Use a Reference Model, and What's the Purpose of KL Divergence?
|
||||
|
||||
When training RL models, optimizing solely for reward may lead to unexpected behaviors, where the model exploits the environment in ways that don't align with good language generation. In the case of RLHF, we use a reward model trained to predict whether a generated text is highly ranked by humans.
|
||||
|
||||
However, the RL model being optimized against the reward model may learn patterns that yield high reward but do not represent good language. This can result in extreme cases where the model generates texts with excessive exclamation marks or emojis to maximize the reward. In some worst-case scenarios, the model may generate patterns completely unrelated to natural language yet receive high rewards, similar to adversarial attacks.
|
||||
|
||||
<div style="text-align: center">
|
||||
<img src="https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/kl-example.png">
|
||||
<p style="text-align: center;"> <b>Figure:</b> Samples without a KL penalty from <a href="https://huggingface.co/papers/1909.08593">https://huggingface.co/papers/1909.08593</a>. </p>
|
||||
</div>
|
||||
|
||||
To address this issue, we add a penalty to the reward function based on the KL divergence between the current model and the reference model. By doing this, we encourage the model to stay close to what the reference model generates.
|
||||
|
||||
## What Is the Concern with Negative KL Divergence?
|
||||
|
||||
If you generate text by purely sampling from the model distribution things work fine in general. But when you use the `generate` method there are a few caveats because it does not always purely sample depending on the settings which can cause KL-divergence to go negative. Essentially when the active model achieves `log_p_token_active < log_p_token_ref` we get negative KL-div. This can happen in several cases:
|
||||
|
||||
- **top-k sampling**: the model can smooth out the probability distribution causing the top-k tokens having a smaller probability than those of the reference model but they still are selected
|
||||
- **min_length**: this ignores the EOS token until `min_length` is reached. thus the model can assign a very low log prob to the EOS token and very high probs to all others until min_length is reached
|
||||
|
||||
These are just a few examples. Why is negative KL an issue? The total reward `R` is computed `R = r - beta * KL` so if the model can learn how to drive KL-divergence negative it effectively gets a positive reward. In many cases it can be much easier to exploit such a bug in the generation than actually learning the reward function. In addition the KL can become arbitrarily small thus the actual reward can be very small compared to it.
|
||||
|
||||
So how should you generate text for PPO training? Let's have a look!
|
||||
|
||||
## How to generate text for training?
|
||||
|
||||
In order to avoid the KL issues described above we recommend to use the following settings:
|
||||
|
||||
```python
|
||||
generation_kwargs = {
|
||||
"min_length": -1, # don't ignore the EOS token (see above)
|
||||
"top_k": 0.0, # no top-k sampling
|
||||
"top_p": 1.0, # no nucleus sampling
|
||||
"do_sample": True, # yes, we want to sample
|
||||
"pad_token_id": tokenizer.eos_token_id, # most decoder models don't have a padding token - use EOS token instead
|
||||
"max_new_tokens": 32, # specify how many tokens you want to generate at most
|
||||
}
|
||||
```
|
||||
|
||||
With these settings we usually don't encounter any issues. You can also experiment with other settings but if you encounter issues with negative KL-divergence try to go back to these and see if they persist.
|
||||
|
||||
## How can debug your own use-case?
|
||||
|
||||
Debugging the RL pipeline can be challenging due to its complexity. Here are some tips and suggestions to make the process easier:
|
||||
|
||||
- **Start from a working example**: Begin with a working example from the trl repository and gradually modify it to fit your specific use-case. Changing everything at once can make it difficult to identify the source of potential issues. For example, you can start by replacing the model in the example and once you figure out the best hyperparameters try to switch to your dataset and reward model. If you change everything at once you won't know where a potential problem comes from.
|
||||
- **Start small, scale later**: Training large models can be very slow and take several hours or days until you see any improvement. For debugging this is not a convenient timescale so try to use small model variants during the development phase and scale up once that works. That being said you sometimes have to be careful as small models might not have the capacity to solve a complicated task either.
|
||||
- **Start simple**: Try to start with a minimal example and build complexity from there. Your use-case might require for example a complicated reward function consisting of many different rewards - try to use one signal first and see if you can optimize that and then add more complexity after that.
|
||||
- **Inspect the generations**: It's always a good idea to inspect what the model is generating. Maybe there is a bug in your post-processing or your prompt. Due to bad settings you might cut-off generations too soon. These things are very hard to see on the metrics but very obvious if you look at the generations.
|
||||
- **Inspect the reward model**: If your reward is not improving over time maybe there's an issue with the reward model. You can look at extreme cases to see if it does what it should: e.g. in the sentiment case you can check if simple positive and negative examples really get different rewards. And you can look at the distribution of your dataset. Finally, maybe the reward is dominated by the query which the model can't affect so you might need to normalize this (e.g. reward of query+response minus reward of the query).
|
||||
|
||||
These are just a few tips that we find helpful - if you have more useful tricks feel free to open a PR to add them as well!
|
@ -1,106 +0,0 @@
|
||||
# Logging
|
||||
|
||||
As reinforcement learning algorithms are historically challenging to debug, it's important to pay careful attention to logging.
|
||||
By default, TRL trainers like [`PPOTrainer`] and [`GRPOTrainer`] save a lot of relevant information to supported experiment trackers like Trackio, Weights & Biases (wandb) or TensorBoard.
|
||||
|
||||
Upon initialization, pass the `report_to` argument to the respective configuration object (e.g., [`PPOConfig`] for [`PPOTrainer`], or [`GRPOConfig`] for [`GRPOTrainer`]):
|
||||
|
||||
```python
|
||||
# For PPOTrainer
|
||||
ppo_config = PPOConfig(
|
||||
# ...,
|
||||
report_to="trackio" # or "wandb" or "tensorboard"
|
||||
)
|
||||
|
||||
# For GRPOTrainer
|
||||
grpo_config = GRPOConfig(
|
||||
# ...,
|
||||
report_to="trackio" # or "wandb" or "tensorboard"
|
||||
)
|
||||
```
|
||||
|
||||
If you want to log with TensorBoard, you might also need to specify logging directories, for example, by adding `logging_dir=PATH_TO_LOGS` to the configuration object (e.g., [`PPOConfig`] or [`GRPOConfig`]).
|
||||
|
||||
## PPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data:
|
||||
|
||||
* `eps`: Tracks the number of episodes per second.
|
||||
* `objective/kl`: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
|
||||
* `objective/entropy`: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
|
||||
* `objective/non_score_reward`: The mean reward from non-score-related sources, basically `beta * kl.sum(1)`, where `beta` is the KL penalty coefficient and `kl` is the per-token KL divergence.
|
||||
* `objective/rlhf_reward`: The mean RLHF reward, which is `score - non_score_reward`.
|
||||
* `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
* `policy/approxkl_avg`: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as `objective/kl`.
|
||||
* `policy/clipfrac_avg`: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
|
||||
* `loss/policy_avg`: The average policy loss, indicating how well the policy is performing.
|
||||
* `loss/value_avg`: The average value loss, indicating the difference between the predicted value and the actual reward.
|
||||
* `val/clipfrac_avg`: The average fraction of value function updates that are clipped, similar to `policy/clipfrac_avg` but for the value function.
|
||||
* `policy/entropy_avg`: The average entropy of the policy during training, indicating how diverse the policy's actions are.
|
||||
* `val/ratio`: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
|
||||
* `val/ratio_var`: The variance of the `val/ratio`, indicating the variability in policy changes.
|
||||
* `val/num_eos_tokens`: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
|
||||
* `lr`: The current learning rate used by the optimizer.
|
||||
* `episode`: The current episode count in the training process.
|
||||
|
||||
### Crucial values
|
||||
|
||||
During training, many values are logged, here are the most important ones:
|
||||
|
||||
1. `objective/scores`: The mean scores returned by the reward model / environment.
|
||||
1. `objective/rlhf_reward`: The mean RLHF reward. This is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
|
||||
1. `objective/non_score_reward`: The mean reward from non-score-related sources (e.g., KL penalty).
|
||||
|
||||
Here are some parameters that are useful to monitor for stability (when these diverge or collapse to 0, try tuning variables):
|
||||
|
||||
1. `loss/value_avg`: The average value loss. It will spike / NaN when not going well.
|
||||
1. `val/ratio`: The mean ratio of the current policy probability to the old policy probability. This number should float around 1.0. If this `ratio` is too high (e.g., 2.0 or 1000.0) or too small (e.g., 0.1), it means the updates between consecutive policies are too drastic.
|
||||
1. `policy/clipfrac_avg` and `policy/approxkl_avg`: If `val/ratio` is too high, the `ratio` is going to get clipped, resulting in high `policy/clipfrac_avg` and high `policy/approxkl_avg` as well.
|
||||
1. `objective/kl`: The mean KL divergence. It should stay positive and ideally not too large, so that the policy is not too far away from the reference policy.
|
||||
|
||||
## GRPO Logging
|
||||
|
||||
Here's a brief explanation for the logged metrics provided in the data for the GRPO trainer:
|
||||
|
||||
* `num_tokens`: Total number of input tokens processed during training so far.
|
||||
|
||||
### Completions
|
||||
|
||||
* `completions/mean_length`: Mean length of all generated completions (including those not ending with an EOS token).
|
||||
* `completions/min_length`: Minimum length among all generated completions.
|
||||
* `completions/max_length`: Maximum length among all generated completions.
|
||||
* `completions/clipped_ratio`: The ratio of completions that did not end with an EOS token before reaching the maximum generation length (i.e., they were truncated).
|
||||
* `completions/mean_terminated_length`: Mean length of only those completions that successfully ended with an EOS token.
|
||||
* `completions/min_terminated_length`: Minimum length among completions that ended with an EOS token.
|
||||
* `completions/max_terminated_length`: Maximum length among completions that ended with an EOS token.
|
||||
|
||||
### Rewards
|
||||
|
||||
* `rewards/{reward_func_name}/mean`: The mean reward obtained from a specific, named reward function (e.g., `rewards/my_custom_reward/mean`). This is logged for each reward function used.
|
||||
* `rewards/{reward_func_name}/std`: The standard deviation of rewards from a specific, named reward function.
|
||||
* `reward`: The overall mean of the (potentially weighted and, if `args.scale_rewards` is true, normalized) rewards, after group-wise normalization (advantages).
|
||||
* `reward_std`: The standard deviation of the (potentially weighted) rewards *before* group-wise normalization for advantages.
|
||||
|
||||
### Policy and Loss Metrics
|
||||
|
||||
* `kl`: The mean Kullback-Leibler (KL) divergence between the current policy and the reference policy. This is logged only if `beta` (the KL coefficient in [`GRPOConfig`]) is non-zero.
|
||||
* `entropy`: Average entropy of token predictions across generated completions.
|
||||
* If Liger GRPOLoss is used (`use_liger_loss: True` in [`GRPOConfig`]):
|
||||
* `clip_ratio`: The fraction of policy updates where the probability ratio was clipped according to the GRPO loss's epsilon bounds.
|
||||
* If standard GRPOLoss is used (`use_liger_loss: False`):
|
||||
* `clip_ratio/low_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the lower bound `1 - epsilon_low` (occurs when advantage is negative and ratio is below the bound).
|
||||
* `clip_ratio/low_min`: The minimum observed fraction for `clip_ratio/low_mean` across batches/processes.
|
||||
* `clip_ratio/high_mean`: The mean fraction of instances where the probability ratio `r_t(θ)` was clipped at the upper bound `1 + epsilon_high` (occurs when advantage is positive and ratio is above the bound).
|
||||
* `clip_ratio/high_max`: The maximum observed fraction for `clip_ratio/high_mean` across batches/processes.
|
||||
* `clip_ratio/region_mean`: The mean fraction of instances where the probability ratio was clipped at either the lower or upper bound.
|
||||
|
||||
### Crucial GRPO values
|
||||
|
||||
During GRPO training, monitor these values for insights into performance and stability:
|
||||
|
||||
* `reward`: This is the primary objective. It reflects the (group-wise normalized) rewards the policy is achieving. It should generally increase during successful training.
|
||||
* `kl`: If `beta > 0`, this tracks the divergence from the reference model. Keep an eye on it to ensure the policy doesn't stray too far, which can lead to instability.
|
||||
* `clip_ratio/*` (either `clip_ratio` for Liger loss or the more detailed `clip_ratio/...` metrics for standard loss): These indicate how often the policy updates are being constrained by the GRPO clipping mechanism. Very high values might suggest that the policy is trying to change too drastically (potentially due to large advantages or a learning rate that's too high) or that the epsilon clipping range is too restrictive.
|
||||
* `completions/clipped_ratio`: A high ratio here indicates that the model is frequently generating completions that are cut off by `max_completion_length` rather than naturally ending with an EOS token. This might suggest issues with learning sequence termination or that `max_completion_length` is too short.
|
||||
* `rewards/{reward_func_name}/mean`: Monitoring the mean of individual reward functions can help diagnose which aspects of the desired behavior the model is learning or struggling with, especially when using multiple reward sources.
|
||||
* `entropy`: Measures how uncertain the policy is in its action choices, higher entropy suggests more exploration. A collapse in entropy means the policy is becoming overconfident and deterministic, often too early. This can stall learning by reducing exploration and making updates overly biased. Stable but non-zero entropy is usually a sign that the policy retains flexibility and continues to explore.
|
@ -90,7 +90,7 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
model_name,
|
||||
peft_config=lora_config,
|
||||
reward_adapter=rm_adapter_id,
|
||||
load_in_8bit=True,
|
||||
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
|
||||
)
|
||||
|
||||
...
|
||||
|
@ -91,7 +91,6 @@ trl reward --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
|
||||
- [SFT Trainer](sft_trainer) - Complete SFT guide
|
||||
- [DPO Trainer](dpo_trainer) - Preference alignment
|
||||
- [GRPO Trainer](grpo_trainer) - Group relative policy optimization
|
||||
- [Training FAQ](how_to_train) - Common questions
|
||||
|
||||
### 🚀 Scale Up
|
||||
|
||||
@ -141,4 +140,4 @@ Try adjusting the learning rate:
|
||||
training_args = SFTConfig(learning_rate=2e-5) # Good starting point
|
||||
```
|
||||
|
||||
For more help, see our [Training FAQ](how_to_train) or open an [issue on GitHub](https://github.com/huggingface/trl/issues).
|
||||
For more help, open an [issue on GitHub](https://github.com/huggingface/trl/issues).
|
||||
|
@ -2,14 +2,14 @@
|
||||
|
||||
This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].
|
||||
|
||||
## Format rewards
|
||||
## accuracy_reward
|
||||
|
||||
### think_format_reward
|
||||
[[autodoc]] rewards.accuracy_reward
|
||||
|
||||
## think_format_reward
|
||||
|
||||
[[autodoc]] rewards.think_format_reward
|
||||
|
||||
## Other rewards
|
||||
|
||||
### get_soft_overlong_punishment
|
||||
## get_soft_overlong_punishment
|
||||
|
||||
[[autodoc]] rewards.get_soft_overlong_punishment
|
||||
|
@ -1,159 +0,0 @@
|
||||
# Using LLaMA models with TRL
|
||||
|
||||
We've begun rolling out examples to use Meta's LLaMA models in `trl` (see [Meta's LLaMA release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) for the original LLaMA model).
|
||||
|
||||
## Efficient training strategies
|
||||
|
||||
Even training the smallest LLaMA model requires an enormous amount of memory. Some quick math: in bf16, every parameter uses 2 bytes (in fp32 4 bytes) in addition to 8 bytes used, e.g., in the Adam optimizer (see the [performance docs](https://huggingface.co/docs/transformers/perf_train_gpu_one#optimizer) in Transformers for more info). So a 7B parameter model would use `(2+8)*7B=70GB` just to fit in memory and would likely need more when you compute intermediate values such as attention scores. So you couldn’t train the model even on a single 80GB A100 like that. You can use some tricks, like more efficient optimizers of half-precision training, to squeeze a bit more into memory, but you’ll run out sooner or later.
|
||||
|
||||
Another option is to use Parameter-Efficient Fine-Tuning (PEFT) techniques, such as the [`peft`](https://github.com/huggingface/peft) library, which can perform low-rank adaptation (LoRA) on a model loaded in 8-bit.
|
||||
For more on `peft` + `trl`, see the [Peft integration](peft_integration) docs.
|
||||
|
||||
Loading the model in 8bit reduces the memory footprint drastically since you only need one byte per parameter for the weights (e.g. 7B LlaMa is 7GB in memory).
|
||||
Instead of training the original weights directly, LoRA adds small adapter layers on top of some specific layers (usually the attention layers); thus, the number of trainable parameters is drastically reduced.
|
||||
|
||||
In this scenario, a rule of thumb is to allocate ~1.2-1.4GB per billion parameters (depending on the batch size and sequence length) to fit the entire fine-tuning setup.
|
||||
This enables fine-tuning larger models (up to 50-60B scale models on a NVIDIA A100 80GB) at low cost.
|
||||
|
||||
Now we can fit very large models into a single GPU, but the training might still be very slow.
|
||||
The simplest strategy in this scenario is data parallelism: we replicate the same training setup into separate GPUs and pass different batches to each GPU.
|
||||
With this, you can parallelize the forward/backward passes of the model and scale with the number of GPUs.
|
||||
|
||||

|
||||
|
||||
We use either the `transformers.Trainer` or `accelerate`, which both support data parallelism without any code changes, by simply passing arguments when calling the scripts with `torchrun` or `accelerate launch`. The following runs a training script with 8 GPUs on a single machine with `accelerate` and `torchrun`, respectively.
|
||||
|
||||
```bash
|
||||
accelerate launch --multi_gpu --num_machines 1 --num_processes 8 my_accelerate_script.py
|
||||
torchrun --nnodes 1 --nproc_per_node 8 my_torch_script.py
|
||||
```
|
||||
|
||||
## Supervised fine-tuning
|
||||
|
||||
Before we start training reward models and tuning our model with RL, it helps if the model is already good in the domain we are interested in.
|
||||
In our case, we want it to answer questions, while for other use cases, we might want it to follow instructions, in which case instruction tuning is a great idea.
|
||||
The easiest way to achieve this is by continuing to train the language model with the language modeling objective on texts from the domain or task.
|
||||
The [StackExchange dataset](https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences) is enormous (over 10 million instructions), so we can easily train the language model on a subset of it.
|
||||
|
||||
There is nothing special about fine-tuning the model before doing RLHF - it’s just the causal language modeling objective from pretraining that we apply here.
|
||||
To use the data efficiently, we use a technique called packing: instead of having one text per sample in the batch and then padding to either the longest text or the maximal context of the model, we concatenate a lot of texts with an EOS token in between and cut chunks of the context size to fill the batch without any padding.
|
||||
|
||||

|
||||
|
||||
With this approach the training is much more efficient as each token that is passed through the model is also trained in contrast to padding tokens which are usually masked from the loss.
|
||||
If you don't have much data and are more concerned about occasionally cutting off some tokens that are overflowing the context you can also use a classical data loader.
|
||||
|
||||
```python
|
||||
# load model in 8bit
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
args.model_path,
|
||||
load_in_8bit=True,
|
||||
device_map={"": Accelerator().local_process_index}
|
||||
)
|
||||
model = prepare_model_for_kbit_training(model)
|
||||
|
||||
# add LoRA to model
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.05,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
model = get_peft_model(model, config)
|
||||
```
|
||||
|
||||
We train the model for a few thousand steps with the causal language modeling objective and save the model.
|
||||
Since we will tune the model again with different objectives, we merge the adapter weights with the original model weights.
|
||||
|
||||
**Disclaimer:** due to LLaMA's license, we release only the adapter weights for this and the model checkpoints in the following sections.
|
||||
You can apply for access to the base model's weights by filling out Meta AI's [form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform) and then converting them to the 🤗 Transformers format by running this [script](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py).
|
||||
Note that you'll also need to install 🤗 Transformers from source until the `v4.28` is released.
|
||||
|
||||
Now that we have fine-tuned the model for the task, we are ready to train a reward model.
|
||||
|
||||
## Reward modeling and human preferences
|
||||
|
||||
In principle, we could fine-tune the model using RLHF directly with the human annotations.
|
||||
However, this would require us to send some samples to humans for rating after each optimization iteration.
|
||||
This is expensive and slow due to the number of training samples needed for convergence and the inherent latency of human reading and annotator speed.
|
||||
|
||||
A trick that works well instead of direct feedback is training a reward model on human annotations collected before the RL loop.
|
||||
The goal of the reward model is to imitate how a human would rate a text. There are several possible strategies to build a reward model: the most straightforward way would be to predict the annotation (e.g. a rating score or a binary value for “good”/”bad”).
|
||||
In practice, what works better is to predict the ranking of two examples, where the reward model is presented with two candidates `(y_k, y_j)` for a given prompt `x` and has to predict which one would be rated higher by a human annotator.
|
||||
|
||||
With the StackExchange dataset, we can infer which of the two answers was preferred by the users based on the score.
|
||||
With that information and the loss defined above, we can then modify the `transformers.Trainer` by adding a custom loss function.
|
||||
|
||||
```python
|
||||
class RewardTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
rewards_j = model(input_ids=inputs["input_ids_j"], attention_mask=inputs["attention_mask_j"])[0]
|
||||
rewards_k = model(input_ids=inputs["input_ids_k"], attention_mask=inputs["attention_mask_k"])[0]
|
||||
loss = -nn.functional.logsigmoid(rewards_j - rewards_k).mean()
|
||||
if return_outputs:
|
||||
return loss, {"rewards_j": rewards_j, "rewards_k": rewards_k}
|
||||
return loss
|
||||
```
|
||||
|
||||
We utilize a subset of a 100,000 pair of candidates and evaluate on a held-out set of 50,000. With a modest training batch size of 4, we train the Llama model using the LoRA `peft` adapter for a single epoch using the Adam optimizer with BF16 precision. Our LoRA configuration is:
|
||||
|
||||
```python
|
||||
peft_config = LoraConfig(
|
||||
task_type=TaskType.SEQ_CLS,
|
||||
inference_mode=False,
|
||||
r=8,
|
||||
lora_alpha=32,
|
||||
lora_dropout=0.1,
|
||||
)
|
||||
```
|
||||
|
||||
As detailed in the next section, the resulting adapter can be merged into the frozen model and saved for further downstream use.
|
||||
|
||||
## Reinforcement Learning from Human Feedback
|
||||
|
||||
With the fine-tuned language model and the reward model at hand, we are now ready to run the RL loop. It follows roughly three steps:
|
||||
|
||||
1. Generate responses from prompts,
|
||||
2. Rate the responses with the reward model,
|
||||
3. Run a reinforcement learning policy-optimization step with the ratings.
|
||||
|
||||
The Query and Response prompts are templated as follows before being tokenized and passed to the model:
|
||||
|
||||
```bash
|
||||
Question: <Query>
|
||||
|
||||
Answer: <Response>
|
||||
```
|
||||
|
||||
The same template was used for SFT, RM and RLHF stages.
|
||||
Once more, we utilize `peft` for memory-efficient training, which offers an extra advantage in the RLHF context.
|
||||
Here, the reference model and policy share the same base, the SFT model, which we load in 8-bit and freeze during training.
|
||||
We exclusively optimize the policy's LoRA weights using PPO while sharing the base model's weights.
|
||||
|
||||
```python
|
||||
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
|
||||
question_tensors = batch["input_ids"]
|
||||
|
||||
# sample from the policy and to generate responses
|
||||
response_tensors = ppo_trainer.generate(
|
||||
question_tensors,
|
||||
return_prompt=False,
|
||||
length_sampler=output_length_sampler,
|
||||
**generation_kwargs,
|
||||
)
|
||||
batch["response"] = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
|
||||
|
||||
# Compute sentiment score
|
||||
texts = [q + r for q, r in zip(batch["query"], batch["response"])]
|
||||
pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
|
||||
rewards = [torch.tensor(output[0]["score"] - script_args.reward_baseline) for output in pipe_outputs]
|
||||
|
||||
# Run PPO step
|
||||
stats = ppo_trainer.step(question_tensors, response_tensors, rewards)
|
||||
# Log stats to Wandb
|
||||
ppo_trainer.log_stats(stats, batch, rewards)
|
||||
```
|
||||
|
||||
For the rest of the details and evaluation, please refer to our [blog post on StackLLaMA](https://huggingface.co/blog/stackllama).
|
@ -70,8 +70,6 @@ import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
from trl import (
|
||||
GRPOConfig,
|
||||
@ -83,7 +81,7 @@ from trl import (
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -149,54 +147,6 @@ if __name__ == "__main__":
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
|
||||
|
||||
################
|
||||
# Reward Function for Training
|
||||
################
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -57,8 +57,6 @@ import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
from trl import (
|
||||
GRPOConfig,
|
||||
@ -70,7 +68,7 @@ from trl import (
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -120,54 +118,6 @@ if __name__ == "__main__":
|
||||
train_dataset = train_dataset.remove_columns(["messages", "problem"])
|
||||
eval_dataset = eval_dataset.remove_columns(["messages", "problem"])
|
||||
|
||||
################
|
||||
# Reward Function for Training
|
||||
################
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -57,8 +57,6 @@ import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
from trl import (
|
||||
GRPOConfig,
|
||||
@ -70,7 +68,7 @@ from trl import (
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -136,54 +134,6 @@ if __name__ == "__main__":
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
|
||||
|
||||
################
|
||||
# Reward Function for Training
|
||||
################
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -87,8 +87,6 @@ import os
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from transformers import AutoConfig, AutoProcessor, GenerationConfig
|
||||
|
||||
from trl import (
|
||||
@ -102,7 +100,7 @@ from trl import (
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -192,54 +190,6 @@ if __name__ == "__main__":
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
|
||||
|
||||
################
|
||||
# Reward Function for Training (same as GRPO VLM)
|
||||
################
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -33,12 +33,10 @@ import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
from peft import LoraConfig
|
||||
|
||||
from trl import RLOOConfig, RLOOTrainer
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -67,52 +65,6 @@ def main():
|
||||
train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"])
|
||||
eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"])
|
||||
|
||||
# Reward function for training
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
# Training
|
||||
training_args = RLOOConfig(
|
||||
output_dir="Qwen3-0.6B-RLOO",
|
||||
|
@ -70,8 +70,6 @@ import os
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
from trl import (
|
||||
ModelConfig,
|
||||
@ -83,7 +81,7 @@ from trl import (
|
||||
get_peft_config,
|
||||
get_quantization_config,
|
||||
)
|
||||
from trl.rewards import think_format_reward
|
||||
from trl.rewards import accuracy_reward, think_format_reward
|
||||
|
||||
|
||||
# Enable logging in a Hugging Face Space
|
||||
@ -149,54 +147,6 @@ if __name__ == "__main__":
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"] if training_args.eval_strategy != "no" else None
|
||||
|
||||
################
|
||||
# Reward Function for Training
|
||||
################
|
||||
def accuracy_reward(completions, solution: list[str], **kwargs):
|
||||
"""Reward function that checks if the completion matches the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
"""
|
||||
rewards = []
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
for content, sol in zip(contents, solution):
|
||||
try:
|
||||
gold_parsed = parse(sol, extraction_mode="first_match")
|
||||
except Exception:
|
||||
gold_parsed = []
|
||||
|
||||
if len(gold_parsed) != 0:
|
||||
# Try parsing predicted answer too
|
||||
try:
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception as e:
|
||||
print(f"verify failed: {e}, answer: {content}, gold: {sol}")
|
||||
reward = None
|
||||
else:
|
||||
# fallback to text match
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
||||
|
||||
################
|
||||
# Training
|
||||
################
|
||||
|
@ -89,6 +89,9 @@ vlm = [
|
||||
"torchvision",
|
||||
"num2words==0.5.14"
|
||||
]
|
||||
math_verify = [
|
||||
"math-verify>=0.5.2",
|
||||
]
|
||||
dev = [
|
||||
# bco
|
||||
"scikit-learn",
|
||||
|
@ -329,11 +329,11 @@ class TestGRPOTrainerSlow(TrlTestCase):
|
||||
assert lora_params_changed, "No LoRA parameters were updated during training."
|
||||
|
||||
except torch.OutOfMemoryError as e:
|
||||
self.skipTest(f"Skipping VLM training test due to insufficient GPU memory: {e}")
|
||||
pytest.skip(f"Skipping VLM training test due to insufficient GPU memory: {e}")
|
||||
except Exception as e:
|
||||
# Check for other memory-related errors
|
||||
if any(keyword in str(e).lower() for keyword in ["memory", "cuda", "out of memory", "insufficient"]):
|
||||
self.skipTest(f"Skipping VLM training test due to hardware constraints: {e}")
|
||||
pytest.skip(f"Skipping VLM training test due to hardware constraints: {e}")
|
||||
else:
|
||||
raise
|
||||
|
||||
@ -474,11 +474,11 @@ class TestGRPOTrainerSlow(TrlTestCase):
|
||||
"decrease gpu memory",
|
||||
]
|
||||
):
|
||||
self.skipTest(f"Skipping vLLM colocate test due to hardware constraints: {e}")
|
||||
pytest.skip(f"Skipping vLLM colocate test due to hardware constraints: {e}")
|
||||
elif "KeyError" in str(e) and "RANK" in str(e):
|
||||
self.skipTest(f"Skipping vLLM colocate test due to environment setup issues: {e}")
|
||||
pytest.skip(f"Skipping vLLM colocate test due to environment setup issues: {e}")
|
||||
elif "ValueError" in str(e) and "memory" in str(e).lower():
|
||||
self.skipTest(f"Skipping vLLM colocate test due to memory constraints: {e}")
|
||||
pytest.skip(f"Skipping vLLM colocate test due to memory constraints: {e}")
|
||||
else:
|
||||
raise
|
||||
finally:
|
||||
@ -541,11 +541,11 @@ class TestGRPOTrainerSlow(TrlTestCase):
|
||||
"decrease gpu memory",
|
||||
]
|
||||
):
|
||||
self.skipTest(f"Skipping vLLM training test due to hardware constraints: {e}")
|
||||
pytest.skip(f"Skipping vLLM training test due to hardware constraints: {e}")
|
||||
elif "KeyError" in str(e) and "RANK" in str(e):
|
||||
self.skipTest(f"Skipping vLLM training test due to environment setup issues: {e}")
|
||||
pytest.skip(f"Skipping vLLM training test due to environment setup issues: {e}")
|
||||
elif "ValueError" in str(e) and "memory" in str(e).lower():
|
||||
self.skipTest(f"Skipping vLLM training test due to memory constraints: {e}")
|
||||
pytest.skip(f"Skipping vLLM training test due to memory constraints: {e}")
|
||||
else:
|
||||
raise
|
||||
|
||||
|
@ -642,6 +642,7 @@ class TestDPOTrainer(TrlTestCase):
|
||||
def test_dpo_lora_bf16_autocast_llama(self):
|
||||
# Note this test only works on compute capability > 7 GPU devices
|
||||
from peft import LoraConfig
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
@ -655,7 +656,9 @@ class TestDPOTrainer(TrlTestCase):
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
|
||||
)
|
||||
|
||||
training_args = DPOConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
@ -725,6 +728,7 @@ class TestDPOTrainer(TrlTestCase):
|
||||
)
|
||||
def test_dpo_lora_bf16_autocast(self, loss_type, pre_compute, gen_during_eval):
|
||||
from peft import LoraConfig
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
lora_config = LoraConfig(
|
||||
r=16,
|
||||
@ -735,7 +739,9 @@ class TestDPOTrainer(TrlTestCase):
|
||||
)
|
||||
|
||||
# lora model
|
||||
model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
|
||||
)
|
||||
|
||||
training_args = DPOConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
@ -1416,6 +1422,7 @@ class TestDPOVisionTrainer(TrlTestCase):
|
||||
# ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",),
|
||||
("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
|
||||
("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
|
||||
("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
|
||||
]
|
||||
)
|
||||
def test_vdpo_trainer(self, model_id):
|
||||
|
@ -259,7 +259,7 @@ class TestGKDTrainer(TrlTestCase):
|
||||
|
||||
# Ensure liger fused JSD path is enabled; if not, skip (runtime may lack system libs)
|
||||
if not getattr(trainer, "use_liger_gkd_loss", False):
|
||||
self.skipTest("Liger fused JSD not enabled at runtime; skipping fused-loss assertion")
|
||||
pytest.skip("Liger fused JSD not enabled at runtime; skipping fused-loss assertion")
|
||||
|
||||
trainer.train()
|
||||
|
||||
|
@ -101,9 +101,12 @@ class TestPeftModel(TrlTestCase):
|
||||
Simply creates a peft model and checks that it can be loaded.
|
||||
"""
|
||||
from bitsandbytes.nn import Linear8bitLt
|
||||
from transformers import BitsAndBytesConfig
|
||||
|
||||
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(
|
||||
self.causal_lm_model_id, peft_config=self.lora_config, load_in_8bit=True
|
||||
self.causal_lm_model_id,
|
||||
peft_config=self.lora_config,
|
||||
quantization_config=BitsAndBytesConfig(load_in_8bit=True),
|
||||
)
|
||||
# Check that the number of trainable parameters is correct
|
||||
nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
|
||||
@ -111,7 +114,7 @@ class TestPeftModel(TrlTestCase):
|
||||
assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt)
|
||||
|
||||
causal_lm_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.causal_lm_model_id, load_in_8bit=True, device_map="auto"
|
||||
self.causal_lm_model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="auto"
|
||||
)
|
||||
trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
|
||||
# Check that the number of trainable parameters is correct
|
||||
|
@ -13,9 +13,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from trl.rewards import get_soft_overlong_punishment, think_format_reward
|
||||
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
|
||||
|
||||
from .testing_utils import TrlTestCase
|
||||
from .testing_utils import TrlTestCase, require_math_latex
|
||||
|
||||
|
||||
class TestThinkFormatReward(TrlTestCase):
|
||||
@ -85,3 +85,60 @@ class TestSoftOverlongPunishmentReward:
|
||||
completion_ids = [[1] * 90] # 90 is between 80 and 100
|
||||
rewards = reward_fn(completion_ids)
|
||||
assert round(abs(rewards[0] - -0.5), 4) == 0
|
||||
|
||||
|
||||
class TestAccuracyReward:
|
||||
@require_math_latex
|
||||
def test_accuracy_reward_correct_answer(self):
|
||||
"""Test accuracy_reward with a correct answer."""
|
||||
completion = [[{"content": r"\boxed{\frac{63}{400}}"}], [{"content": r"\boxed{\frac{63}{400}}"}]]
|
||||
solution = [r"\frac{63}{400}", "63/400"]
|
||||
rewards = accuracy_reward(completion, solution)
|
||||
assert rewards[0] == 1.0
|
||||
assert rewards[1] == 1.0
|
||||
|
||||
@require_math_latex
|
||||
def test_accuracy_reward_wrong_answer(self):
|
||||
"""Test accuracy_reward with an incorrect answer."""
|
||||
completion = [[{"content": r"\boxed{\frac{64}{400}}"}]]
|
||||
solution = [r"\frac{63}{400}"]
|
||||
rewards = accuracy_reward(completion, solution)
|
||||
assert rewards[0] == 0.0
|
||||
|
||||
@require_math_latex
|
||||
def test_accuracy_reward_wrong_answer_no_latex(self):
|
||||
"""Test accuracy_reward with an incorrect answer and gold solution with no latex."""
|
||||
completion = [[{"content": r"\boxed{3}"}]]
|
||||
solution = ["6"]
|
||||
rewards = accuracy_reward(completion, solution)
|
||||
assert rewards[0] == 0.0
|
||||
|
||||
@require_math_latex
|
||||
def test_accuracy_reward_unparseable_gold(self):
|
||||
"""Test accuracy_reward with an unparseable gold solution."""
|
||||
completion = [
|
||||
[{"content": "Answer is forty two."}],
|
||||
[{"content": "Some other content."}],
|
||||
[{"content": r"Answer is \boxed{42}."}],
|
||||
[{"content": r"Answer is \boxed{\mathbf{42}}."}], # Make response bold
|
||||
[{"content": r"Answer is \boxed{\textbf{42}}."}], # Different latex command for bold
|
||||
[{"content": r"Answer is \boxed{42}."}],
|
||||
[{"content": r"Answer is \boxed{42.3456}."}],
|
||||
]
|
||||
solution = [
|
||||
"Answer is forty two.",
|
||||
"Answer is forty three.",
|
||||
"Answer is 42.0", # Decimal point
|
||||
"Answer is 42 43 okay?", # Extra space
|
||||
"Answer is 42",
|
||||
r"Answer is \n\boxed{42}", # Newline in gold solution
|
||||
"Answer is 42.34560", # Extra trailing zero
|
||||
]
|
||||
rewards = accuracy_reward(completion, solution)
|
||||
assert rewards[0] == 1.0 # Should revert to exact text match
|
||||
assert rewards[1] == 0.0
|
||||
assert rewards[2] == 1.0
|
||||
assert rewards[3] == 1.0
|
||||
assert rewards[4] == 1.0
|
||||
assert rewards[5] == 1.0
|
||||
assert rewards[6] == 1.0 # Should ignore trailing zeros
|
||||
|
@ -63,4 +63,3 @@ class TestRichProgressCallback(TrlTestCase):
|
||||
)
|
||||
|
||||
trainer.train()
|
||||
trainer.train()
|
||||
|
@ -1477,6 +1477,7 @@ class TestSFTTrainer(TrlTestCase):
|
||||
# To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
|
||||
@pytest.mark.slow
|
||||
@require_vision
|
||||
@pytest.mark.skip(reason="Model google/gemma-3n-E2B-it is gated and requires HF token")
|
||||
def test_train_vlm_gemma_3n(self):
|
||||
# Get the dataset
|
||||
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_language_modeling", split="train")
|
||||
|
@ -22,7 +22,7 @@ from transformers.testing_utils import require_torch_multi_accelerator, torch_de
|
||||
from trl.extras.vllm_client import VLLMClient
|
||||
from trl.scripts.vllm_serve import chunk_list
|
||||
|
||||
from .testing_utils import TrlTestCase, kill_process, require_3_accelerators
|
||||
from .testing_utils import TrlTestCase, kill_process, require_3_accelerators, require_vllm
|
||||
|
||||
|
||||
class TestChunkList(TrlTestCase):
|
||||
@ -53,6 +53,7 @@ class TestChunkList(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_torch_multi_accelerator
|
||||
@require_vllm
|
||||
class TestVLLMClientServer(TrlTestCase):
|
||||
model_id = "Qwen/Qwen2.5-1.5B"
|
||||
|
||||
@ -212,6 +213,7 @@ class TestVLLMClientServerBaseURL(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_3_accelerators
|
||||
@require_vllm
|
||||
class TestVLLMClientServerTP(TrlTestCase):
|
||||
model_id = "Qwen/Qwen2.5-1.5B"
|
||||
|
||||
@ -274,6 +276,7 @@ class TestVLLMClientServerTP(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_3_accelerators
|
||||
@require_vllm
|
||||
class TestVLLMClientServerDP(TrlTestCase):
|
||||
model_id = "Qwen/Qwen2.5-1.5B"
|
||||
|
||||
@ -336,6 +339,7 @@ class TestVLLMClientServerDP(TrlTestCase):
|
||||
|
||||
@pytest.mark.slow
|
||||
@require_torch_multi_accelerator
|
||||
@require_vllm
|
||||
class TestVLLMClientServerDeviceParameter(TrlTestCase):
|
||||
"""Test the device parameter functionality in init_communicator."""
|
||||
|
||||
|
@ -26,12 +26,19 @@ from transformers.testing_utils import torch_device
|
||||
from transformers.utils import is_peft_available, is_rich_available, is_vision_available
|
||||
|
||||
from trl import BaseBinaryJudge, BasePairwiseJudge
|
||||
from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available
|
||||
from trl.import_utils import (
|
||||
is_joblib_available,
|
||||
is_llm_blender_available,
|
||||
is_math_verify_available,
|
||||
is_mergekit_available,
|
||||
is_vllm_available,
|
||||
)
|
||||
|
||||
|
||||
require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")
|
||||
require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml")
|
||||
require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender")
|
||||
require_math_latex = pytest.mark.skipif(not is_math_verify_available(), reason="test requires math_verify")
|
||||
require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit")
|
||||
require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft")
|
||||
require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich")
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import warnings
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -42,8 +43,16 @@ class BestOfNSampler:
|
||||
generation_config ([`~transformers.GenerationConfig`], *optional*):
|
||||
Generation config passed to the underlying model's `generate` method. See
|
||||
[`~transformers.GenerationConfig`] for more details.
|
||||
|
||||
<Deprecated version="0.24.0">
|
||||
|
||||
`BestOfNSampler` is deprecated and will be removed in version 0.25.
|
||||
|
||||
</Deprecated>
|
||||
"""
|
||||
|
||||
warnings.warn("`BestOfNSampler` is deprecated and will be removed in TRL 0.25.", FutureWarning, stacklevel=2)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: PreTrainedModelWrapper,
|
||||
|
@ -31,6 +31,7 @@ _fastapi_available = _is_package_available("fastapi")
|
||||
_joblib_available = _is_package_available("joblib")
|
||||
_liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True)
|
||||
_llm_blender_available = _is_package_available("llm_blender")
|
||||
_math_verify_available = _is_package_available("math_verify")
|
||||
_mergekit_available = _is_package_available("mergekit")
|
||||
_pydantic_available = _is_package_available("pydantic")
|
||||
_requests_available = _is_package_available("requests")
|
||||
@ -61,6 +62,10 @@ def is_llm_blender_available() -> bool:
|
||||
return _llm_blender_available
|
||||
|
||||
|
||||
def is_math_verify_available() -> bool:
|
||||
return _math_verify_available
|
||||
|
||||
|
||||
def is_mergekit_available() -> bool:
|
||||
return _mergekit_available
|
||||
|
||||
|
@ -153,8 +153,13 @@ class PreTrainedModelWrapper(nn.Module):
|
||||
|
||||
current_device = cls._get_current_device()
|
||||
if isinstance(pretrained_model_name_or_path, str):
|
||||
is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
|
||||
is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
|
||||
quantization_config = pretrained_kwargs.get("quantization_config", None)
|
||||
if quantization_config is not None:
|
||||
is_loaded_in_8bit = getattr(quantization_config, "load_in_8bit", False)
|
||||
is_loaded_in_4bit = getattr(quantization_config, "load_in_4bit", False)
|
||||
else:
|
||||
is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
|
||||
is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
|
||||
else:
|
||||
is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False)
|
||||
is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False)
|
||||
|
@ -20,12 +20,14 @@ from ..import_utils import _LazyModule
|
||||
|
||||
|
||||
_import_structure = {
|
||||
"accuracy_rewards": ["accuracy_reward"],
|
||||
"format_rewards": ["think_format_reward"],
|
||||
"other_rewards": ["get_soft_overlong_punishment"],
|
||||
}
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .accuracy_rewards import accuracy_reward
|
||||
from .format_rewards import think_format_reward
|
||||
from .other_rewards import get_soft_overlong_punishment
|
||||
|
||||
|
93
trl/rewards/accuracy_rewards.py
Normal file
93
trl/rewards/accuracy_rewards.py
Normal file
@ -0,0 +1,93 @@
|
||||
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from trl.import_utils import is_math_verify_available
|
||||
|
||||
|
||||
if is_math_verify_available():
|
||||
from latex2sympy2_extended import NormalizationConfig
|
||||
from math_verify import LatexExtractionConfig, parse, verify
|
||||
|
||||
|
||||
def accuracy_reward(completions: list[list[dict[str, str]]], solution: list[str], **kwargs) -> list[Optional[float]]:
|
||||
r"""
|
||||
Reward function that checks if the completion is the same as the ground truth.
|
||||
- If both gold and prediction are parseable → use math verification.
|
||||
- If not parseable → compare as normalized text.
|
||||
|
||||
Args:
|
||||
completions (`list[list[dict[str, str]]]`):
|
||||
List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
|
||||
containing the key `"content"` with the value being the text of the completion.
|
||||
solution: (`list[str]`):
|
||||
List of the raw-text solutions to the questions/problems/prompts.
|
||||
**kwargs:
|
||||
Additional keyword arguments. This function does not use them, but they are required in the function
|
||||
signature to ensure compatibility with trainers like [`GRPOTrainer`].
|
||||
Example:
|
||||
```python
|
||||
>>> from trl.rewards import accuracy_reward
|
||||
|
||||
>>> solution = [r"\frac{1}{3}", r"\frac{1}{3}"]
|
||||
>>> completion = [
|
||||
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{3}}"}],
|
||||
... [{"role": "assistant", "content": r"My answer is \boxed{\frac{1}{2}}"}],
|
||||
... ]
|
||||
>>> accuracy_reward(completion, solution)
|
||||
[1.0, 0.0]
|
||||
```
|
||||
"""
|
||||
if not is_math_verify_available():
|
||||
raise ImportError("Please install the `math_verify` package to use accuracy_reward")
|
||||
|
||||
contents = [completion[0]["content"] for completion in completions]
|
||||
rewards = []
|
||||
for content, sol in zip(contents, solution):
|
||||
gold_parsed = parse(
|
||||
sol,
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
if len(gold_parsed) != 0:
|
||||
# We require the answer to be provided in correct latex (no malformed operators)
|
||||
answer_parsed = parse(
|
||||
content,
|
||||
extraction_config=[
|
||||
LatexExtractionConfig(
|
||||
normalization_config=NormalizationConfig(
|
||||
nits=False,
|
||||
malformed_operators=False,
|
||||
basic_latex=True,
|
||||
boxed="all",
|
||||
units=True,
|
||||
),
|
||||
# Ensures that boxed is tried first
|
||||
boxed_match_priority=0,
|
||||
try_extract_without_anchor=False,
|
||||
)
|
||||
],
|
||||
extraction_mode="first_match",
|
||||
)
|
||||
# Compute binary rewards if verifiable, `None` otherwise to skip this example
|
||||
try:
|
||||
reward = float(verify(gold_parsed, answer_parsed))
|
||||
except Exception:
|
||||
reward = None
|
||||
else:
|
||||
# If the gold solution is not parseable, we assign `None` to skip this example
|
||||
reward = float(content.strip().lower() == sol.strip().lower())
|
||||
rewards.append(reward)
|
||||
|
||||
return rewards
|
@ -41,7 +41,7 @@ from trl import (
|
||||
get_dataset,
|
||||
get_peft_config,
|
||||
)
|
||||
from trl.rewards import get_soft_overlong_punishment, think_format_reward
|
||||
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
|
||||
|
||||
|
||||
reward_funcs_registry = {
|
||||
"accuracy_reward": accuracy_reward,
|
||||
"think_format_reward": think_format_reward,
|
||||
"get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
|
||||
}
|
||||
@ -68,6 +69,7 @@ class GRPOScriptArguments(ScriptArguments):
|
||||
reward_funcs (`list[str]`, *optional*):
|
||||
Reward functions to use. Supported values are:
|
||||
|
||||
- `"accuracy_reward"`
|
||||
- `"think_format_reward"`
|
||||
- `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`)
|
||||
- any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`).
|
||||
@ -83,7 +85,7 @@ class GRPOScriptArguments(ScriptArguments):
|
||||
reward_funcs: Optional[list[str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Reward functions to use. Supported values are: `think_format_reward`, "
|
||||
"help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
|
||||
"`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or "
|
||||
"any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
|
||||
},
|
||||
|
@ -41,7 +41,7 @@ from trl import (
|
||||
get_dataset,
|
||||
get_peft_config,
|
||||
)
|
||||
from trl.rewards import get_soft_overlong_punishment, think_format_reward
|
||||
from trl.rewards import accuracy_reward, get_soft_overlong_punishment, think_format_reward
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@ -51,6 +51,7 @@ os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
|
||||
|
||||
|
||||
reward_funcs_registry = {
|
||||
"accuracy_reward": accuracy_reward,
|
||||
"think_format_reward": think_format_reward,
|
||||
"get_soft_overlong_punishment": get_soft_overlong_punishment(max_completion_len=1280, soft_punish_cache=256),
|
||||
}
|
||||
@ -68,6 +69,7 @@ class RLOOScriptArguments(ScriptArguments):
|
||||
reward_funcs (`list[str]`, *optional*):
|
||||
Reward functions to use. Supported values are:
|
||||
|
||||
- `"accuracy_reward"`
|
||||
- `"think_format_reward"`
|
||||
- `"get_soft_overlong_punishment"` (used value are `max_completion_len=1280`, `soft_punish_cache=256`)
|
||||
- any dotted import path " (e.g., `'my_lib.rewards.custom_reward'`).
|
||||
@ -83,7 +85,7 @@ class RLOOScriptArguments(ScriptArguments):
|
||||
reward_funcs: Optional[list[str]] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Reward functions to use. Supported values are: `think_format_reward`, "
|
||||
"help": "Reward functions to use. Supported values are: `accuracy_reward`, `think_format_reward`, "
|
||||
"`get_soft_overlong_punishment` (used value are `max_completion_len=1280`, `soft_punish_cache=256`), or "
|
||||
"any dotted import path (e.g., `'my_lib.rewards.custom_reward'`)."
|
||||
},
|
||||
|
@ -44,10 +44,12 @@ from .utils import log_table_to_comet_experiment
|
||||
|
||||
|
||||
if is_rich_available():
|
||||
from rich.columns import Columns
|
||||
from rich.console import Console, Group
|
||||
from rich.live import Live
|
||||
from rich.panel import Panel
|
||||
from rich.progress import Progress
|
||||
from rich.table import Table
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
@ -152,74 +154,105 @@ class RichProgressCallback(TrainerCallback):
|
||||
raise ImportError("RichProgressCallback requires the `rich` extra. To install, run `pip install rich`.")
|
||||
|
||||
self.training_bar = None
|
||||
self.prediction_bar = None
|
||||
|
||||
self.training_task_id = None
|
||||
self.prediction_task_id = None
|
||||
|
||||
self.evaluation_bar = None
|
||||
self.training_task = None
|
||||
self.evaluation_task = None
|
||||
self.rich_group = None
|
||||
self.rich_console = None
|
||||
|
||||
self.training_status = None
|
||||
self.current_step = None
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
self.training_bar = Progress()
|
||||
self.prediction_bar = Progress()
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
self.rich_console = Console()
|
||||
|
||||
self.training_status = self.rich_console.status("Nothing to log yet ...")
|
||||
|
||||
self.rich_group = Live(Panel(Group(self.training_bar, self.prediction_bar, self.training_status)))
|
||||
self.rich_group.start()
|
||||
|
||||
self.training_task_id = self.training_bar.add_task("[blue]Training the model", total=state.max_steps)
|
||||
self.current_step = 0
|
||||
self.training_bar = Progress()
|
||||
self.evaluation_bar = Progress()
|
||||
self.rich_console = Console()
|
||||
self.training_status = self.rich_console.status("Nothing to log yet ...")
|
||||
self.rich_group = Live(Panel(Group(self.training_bar, self.evaluation_bar, self.training_status)))
|
||||
self.rich_group.start()
|
||||
self.training_task = self.training_bar.add_task("[blue]Training ", total=state.max_steps)
|
||||
self.current_step = 0
|
||||
|
||||
def on_step_end(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
self.training_bar.update(self.training_task_id, advance=state.global_step - self.current_step, update=True)
|
||||
self.current_step = state.global_step
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
self.training_bar.update(self.training_task, advance=state.global_step - self.current_step, update=True)
|
||||
self.current_step = state.global_step
|
||||
|
||||
def on_prediction_step(self, args, state, control, eval_dataloader=None, **kwargs):
|
||||
if state.is_world_process_zero and has_length(eval_dataloader):
|
||||
if self.prediction_task_id is None:
|
||||
self.prediction_task_id = self.prediction_bar.add_task(
|
||||
"[blue]Predicting on the evaluation dataset", total=len(eval_dataloader)
|
||||
)
|
||||
self.prediction_bar.update(self.prediction_task_id, advance=1, update=True)
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
if has_length(eval_dataloader):
|
||||
if self.evaluation_task is None:
|
||||
self.evaluation_task = self.evaluation_bar.add_task("[blue]Evaluation", total=len(eval_dataloader))
|
||||
self.evaluation_bar.update(self.evaluation_task, advance=1, update=True)
|
||||
|
||||
def on_evaluate(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
if self.prediction_task_id is not None:
|
||||
self.prediction_bar.remove_task(self.prediction_task_id)
|
||||
self.prediction_task_id = None
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
if self.evaluation_task is not None:
|
||||
self.evaluation_bar.remove_task(self.evaluation_task)
|
||||
self.evaluation_task = None
|
||||
|
||||
def on_predict(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
if self.prediction_task_id is not None:
|
||||
self.prediction_bar.remove_task(self.prediction_task_id)
|
||||
self.prediction_task_id = None
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
if self.evaluation_task is not None:
|
||||
self.evaluation_bar.remove_task(self.evaluation_task)
|
||||
self.evaluation_task = None
|
||||
|
||||
def on_log(self, args, state, control, logs=None, **kwargs):
|
||||
if state.is_world_process_zero and self.training_bar is not None:
|
||||
_ = logs.pop("total_flos", None)
|
||||
self.training_status.update(f"[bold green]Status = {str(logs)}")
|
||||
if not (state.is_world_process_zero and self.training_bar):
|
||||
return
|
||||
|
||||
# Group keys by top-level prefix
|
||||
grouped_logs = {}
|
||||
for key, value in logs.items():
|
||||
parts = key.split("/")
|
||||
group = parts[0] if len(parts) > 1 else None
|
||||
subkey = "/".join(parts[1:]) if len(parts) > 1 else key
|
||||
grouped_logs.setdefault(group, {})[subkey] = value
|
||||
|
||||
# Create a table per group
|
||||
tables = []
|
||||
for group_name, metrics in grouped_logs.items():
|
||||
table = Table(
|
||||
title=f"[bold blue]{group_name}[/]" if group_name else None, header_style="bold magenta", box=None
|
||||
)
|
||||
table.add_column("Metric", justify="left", no_wrap=True)
|
||||
table.add_column("Value", justify="right")
|
||||
|
||||
for metric, val in metrics.items():
|
||||
formatted = f"{val:.3f}" if isinstance(val, (float, int)) else str(val)
|
||||
table.add_row(metric, formatted)
|
||||
|
||||
tables.append(Panel(table, border_style="cyan", padding=(0, 1)))
|
||||
|
||||
# Arrange tables in columns using Columns
|
||||
column_layout = Columns(tables, equal=False, expand=True)
|
||||
self.training_status.update(
|
||||
Panel(column_layout, title=f"[bold green]Step {state.global_step}[/bold green]", border_style="green")
|
||||
)
|
||||
|
||||
def on_train_end(self, args, state, control, **kwargs):
|
||||
if state.is_world_process_zero:
|
||||
self.rich_group.stop()
|
||||
if not state.is_world_process_zero:
|
||||
return
|
||||
|
||||
self.training_bar = None
|
||||
self.prediction_bar = None
|
||||
self.training_task_id = None
|
||||
self.prediction_task_id = None
|
||||
self.rich_group = None
|
||||
self.rich_console = None
|
||||
self.training_status = None
|
||||
self.current_step = None
|
||||
self.rich_group.stop()
|
||||
self.training_bar = None
|
||||
self.evaluation_bar = None
|
||||
self.training_task = None
|
||||
self.evaluation_task = None
|
||||
self.rich_group = None
|
||||
self.rich_console = None
|
||||
self.training_status = None
|
||||
self.current_step = None
|
||||
|
||||
|
||||
def _win_rate_completions_df(
|
||||
|
@ -177,6 +177,9 @@ class DataCollatorForPreference(DataCollatorMixin):
|
||||
if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
|
||||
output["ref_chosen_logps"] = ref_chosen_logps
|
||||
output["ref_rejected_logps"] = ref_rejected_logps
|
||||
if "token_type_ids" in examples[0]:
|
||||
token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
|
||||
output["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left")
|
||||
|
||||
return output
|
||||
|
||||
@ -790,6 +793,8 @@ class DPOTrainer(BaseTrainer):
|
||||
output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
|
||||
if "image_sizes" in processed_features:
|
||||
output["image_sizes"] = processed_features["image_sizes"][0]
|
||||
if "token_type_ids" in processed_features:
|
||||
output["token_type_ids"] = processed_features["token_type_ids"][0]
|
||||
|
||||
return output
|
||||
|
||||
@ -804,6 +809,7 @@ class DPOTrainer(BaseTrainer):
|
||||
"chosen_input_ids",
|
||||
"rejected_input_ids",
|
||||
"image_sizes",
|
||||
"token_type_ids",
|
||||
"ref_chosen_logps",
|
||||
"ref_rejected_logps",
|
||||
]
|
||||
@ -991,6 +997,8 @@ class DPOTrainer(BaseTrainer):
|
||||
)
|
||||
if "image_sizes" in batch:
|
||||
output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
|
||||
if "token_type_ids" in batch:
|
||||
output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
|
||||
|
||||
# Concatenate the chosen and rejected completions
|
||||
max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
|
||||
@ -1516,6 +1524,9 @@ class DPOTrainer(BaseTrainer):
|
||||
# Concatenate the prompt and completion inputs
|
||||
input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
|
||||
attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
prompt_token_type_ids = concatenated_batch["token_type_ids"]
|
||||
token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
|
||||
# Mask the prompt but not the completion for the loss
|
||||
loss_mask = torch.cat(
|
||||
(torch.zeros_like(prompt_attention_mask), completion_attention_mask),
|
||||
@ -1528,7 +1539,12 @@ class DPOTrainer(BaseTrainer):
|
||||
# Flush left to reduce the memory usage
|
||||
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
|
||||
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
|
||||
attention_mask, input_ids, loss_mask, token_type_ids
|
||||
)
|
||||
else:
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
attention_mask = attention_mask[:, : self.max_length]
|
||||
input_ids = input_ids[:, : self.max_length]
|
||||
loss_mask = loss_mask[:, : self.max_length]
|
||||
@ -1536,11 +1552,22 @@ class DPOTrainer(BaseTrainer):
|
||||
# Flush right before truncating left, then flush left
|
||||
# [[0, 0, x, x, x, x], -> [[0, 0, x, x],
|
||||
# [0, x, x, x, 0, 0]] [0, x, x, x]]
|
||||
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
|
||||
attention_mask, input_ids, loss_mask, token_type_ids
|
||||
)
|
||||
token_type_ids = token_type_ids[:, -self.max_length :]
|
||||
else:
|
||||
attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
|
||||
input_ids = input_ids[:, -self.max_length :]
|
||||
attention_mask = attention_mask[:, -self.max_length :]
|
||||
loss_mask = loss_mask[:, -self.max_length :]
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
|
||||
attention_mask, input_ids, loss_mask, token_type_ids
|
||||
)
|
||||
else:
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
|
||||
@ -1550,7 +1577,15 @@ class DPOTrainer(BaseTrainer):
|
||||
# Flush left to reduce the memory usage
|
||||
# [[0, 0, x, x, x, x], -> [[x, x, x, x],
|
||||
# [0, x, x, x, 0, 0]] [x, x, x, 0]]
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
|
||||
attention_mask, input_ids, loss_mask, token_type_ids
|
||||
)
|
||||
else:
|
||||
attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
|
||||
|
||||
if "token_type_ids" in concatenated_batch:
|
||||
model_kwargs["token_type_ids"] = token_type_ids
|
||||
|
||||
if self.use_logits_to_keep:
|
||||
# Compute logits_to_keep based on loss_mask pattern:
|
||||
|
Reference in New Issue
Block a user