Overlong-filtering for GRPO (#3248)

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Shirin Yamani
2025-04-08 12:52:52 -06:00
committed by GitHub
parent 7e170612a4
commit 1d7b8c4f70
3 changed files with 113 additions and 3 deletions


@@ -14,6 +14,7 @@
import tempfile
import unittest
from unittest.mock import patch
import torch
from datasets import load_dataset
@@ -915,3 +916,93 @@ class GRPOTrainerTester(unittest.TestCase):
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@patch("transformers.generation.utils.GenerationMixin.generate")
def test_training_with_mask_truncated_completions(self, mock_generate):
"""Test that training works with mask_truncated_completions=True parameter."""
# We mock the generate method because the model's random weights make it extremely unlikely to produce a
# sequence containing the EOS token within the allowed max_completion_length. As a result, all tokens are
# masked in the loss, the model doesn't update, and the final check (which verifies the update) fails.
def fake_generate(prompt_ids, **kwargs):
# pad_token_id = 151643; eos_token_id = 151645
completions_ids = torch.tensor(
[
[1, 2, 3, 4, 5, 6, 7, 8], # this one is truncated
[9, 10, 11, 151645, 151643, 151643, 151643, 151643], # this one contains eos
[12, 13, 14, 15, 16, 17, 18, 151645], # edge case: eos is generated exactly at the last allowed position
],
device=prompt_ids.device,
)
return torch.cat([prompt_ids, completions_ids], dim=1)
mock_generate.side_effect = fake_generate
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
mask_truncated_completions=True, # Enable masking of truncated completions
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
def test_training_with_mask_truncated_completions_all_masked(self):
"""
Test that when all generated completions are truncated (i.e., none contain an EOS token), and
mask_truncated_completions=True, the model receives no effective learning signal and therefore does not update
its parameters.
Here, we don't mock the generate method; instead, we rely on the fact that the probability of the model
generating the EOS token is extremely low, so all generated completions are truncated.
"""
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=8, # reduce the completion length to reduce memory usage
mask_truncated_completions=True, # Enable masking of truncated completions
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
trainer.train()
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
# Check that the params have not changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.")
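For reference, a minimal standalone sketch (not part of the diff) of how the truncation check classifies the three mocked completions from the first test above, with `eos_token_id = 151645`:

```python
import torch

# The three mocked completions from the test above (eos_token_id = 151645).
completion_ids = torch.tensor(
    [
        [1, 2, 3, 4, 5, 6, 7, 8],                             # no EOS -> truncated
        [9, 10, 11, 151645, 151643, 151643, 151643, 151643],  # EOS at position 3
        [12, 13, 14, 15, 16, 17, 18, 151645],                 # EOS at the last allowed position
    ]
)
is_eos = completion_ids == 151645
truncated = ~is_eos.any(dim=1)
print(truncated)  # tensor([ True, False, False]) -> only the first completion is dropped from the loss
```

In the second test, none of the (unmocked) completions contain EOS, so every row is flagged as truncated and, with `mask_truncated_completions=True`, the parameters are expected to stay unchanged.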


@@ -115,6 +115,10 @@ class GRPOConfig(TrainingArguments):
applied. The [Dr. GRPO](https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf)
paper recommends not scaling the rewards, as scaling by the standard deviation introduces a question-level
difficulty bias.
mask_truncated_completions (`bool`, *optional*, defaults to `False`):
When enabled, truncated completions are excluded from the loss calculation, preventing them from being
incorrectly penalized and introducing noise during training. According to the
[DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability.
sync_ref_model (`bool`, *optional*, defaults to `False`):
Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
the `ref_model_mixup_alpha` parameter. This synchronization originates from the
@@ -302,6 +306,14 @@ class GRPOConfig(TrainingArguments):
"deviation introduces a question-level difficulty bias."
},
)
mask_truncated_completions: bool = field(
default=False,
metadata={
"help": "When enabled, truncated completions are excluded from the loss calculation, preventing them from "
"being incorrectly penalized and introducing noise during training. According to the DAPO paper, this is "
"a good practice for training stability."
},
)
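For context, a minimal usage sketch of the new flag (not part of the diff), mirroring the test setup above; the `output_dir` value is only illustrative:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-overlong-filtering",  # illustrative path
    max_completion_length=8,
    mask_truncated_completions=True,  # exclude completions that never emit EOS from the loss
    report_to="none",
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
```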
sync_ref_model: bool = field(
default=False,
metadata={


@@ -409,6 +409,7 @@ class GRPOTrainer(Trainer):
self.repetition_penalty = args.repetition_penalty
self.use_vllm = args.use_vllm
self.use_liger_loss = args.use_liger_loss
self.mask_truncated_completions = args.mask_truncated_completions
# Datasets
if (
@@ -810,6 +811,11 @@ class GRPOTrainer(Trainer):
sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
if self.mask_truncated_completions:
truncated_completions = ~is_eos.any(dim=1)
completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
# Concatenate prompt_mask with completion_mask for logit computation
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
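To illustrate the two new masking lines (a standalone sketch with a generic `eos_token_id = 2`, not part of the diff): a completion that never produced EOS has its entire mask row zeroed, so none of its tokens contribute to the loss:

```python
import torch

eos_token_id = 2  # generic EOS id for this sketch
completion_ids = torch.tensor(
    [
        [5, 6, 2, 0],  # contains EOS -> kept
        [7, 8, 9, 3],  # no EOS -> truncated
    ]
)
is_eos = completion_ids == eos_token_id
# Mask covering tokens up to and including the first EOS (hard-coded here; all ones when there is no EOS).
completion_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])

truncated_completions = ~is_eos.any(dim=1)  # tensor([False,  True])
completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
print(completion_mask)
# tensor([[1, 1, 1, 0],
#         [0, 0, 0, 0]])  -> the truncated completion contributes nothing to the loss
```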
@@ -1047,21 +1053,22 @@
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
# Log the metrics
mode = "eval" if self.control.should_evaluate else "train"
if self.beta != 0.0:
mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())
# Compute the clip ratio
is_clipped = ((coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
(coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
)
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).nanmean().item())
return loss
def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
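A short standalone sketch (not part of the diff) of why the denominator clamp and the switch to `nanmean` matter: when every completion in a batch is truncated and masked, `completion_mask.sum()` is zero, so the old loss computed 0 / 0 = NaN; the clamp turns that degenerate case into a zero loss with zero gradients (matching the all-masked test above), while `nanmean` keeps a NaN metric from one batch or process from contaminating the aggregated average.

```python
import torch

per_token_loss = torch.randn(3, 8, requires_grad=True)  # (B, C) dummy per-token losses
completion_mask = torch.zeros(3, 8)                     # every completion truncated and masked

# Old behaviour: 0 / 0 -> NaN, which would poison the backward pass.
old_loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
print(old_loss)  # tensor(nan, grad_fn=...)

# New behaviour: the clamped denominator yields a harmless zero loss (and zero gradients).
new_loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
print(new_loss)  # tensor(0., grad_fn=...)

# Metrics: nanmean ignores NaN entries instead of letting them dominate the mean.
gathered = torch.tensor([0.1, float("nan"), 0.3])
print(gathered.mean())     # tensor(nan)
print(gathered.nanmean())  # tensor(0.2000)
```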