☕ Overlong-filtering for GRPO (#3248)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
@@ -14,6 +14,7 @@
 import tempfile
 import unittest
+from unittest.mock import patch

 import torch
 from datasets import load_dataset
@@ -915,3 +916,93 @@ class GRPOTrainerTester(unittest.TestCase):
         for n, param in previous_trainable_params.items():
             new_param = trainer.model.get_parameter(n)
             self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    @patch("transformers.generation.utils.GenerationMixin.generate")
+    def test_training_with_mask_truncated_completions(self, mock_generate):
+        """Test that training works with the mask_truncated_completions=True parameter."""
+
+        # We mock the generate method because the model's random weights make it extremely unlikely to produce a
+        # sequence containing the EOS token within the allowed max_completion_length. As a result, all tokens are
+        # masked in the loss, the model doesn't update, and the final check (which verifies the update) fails.
+        def fake_generate(prompt_ids, **kwargs):
+            # pad_token_id = 151643; eos_token_id = 151645
+            completions_ids = torch.tensor(
+                [
+                    [1, 2, 3, 4, 5, 6, 7, 8],  # this one is truncated
+                    [9, 10, 11, 151645, 151643, 151643, 151643, 151643],  # this one contains eos
+                    [12, 13, 14, 15, 16, 17, 18, 151645],  # particular case, eos is generated just within the limit
+                ],
+                device=prompt_ids.device,
+            )
+            return torch.cat([prompt_ids, completions_ids], dim=1)
+
+        mock_generate.side_effect = fake_generate
+
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=8,  # reduce the completion length to reduce memory usage
+                mask_truncated_completions=True,  # enable masking of truncated completions
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    def test_training_with_mask_truncated_completions_all_masked(self):
+        """
+        Test that when all generated completions are truncated (i.e., none contain an EOS token) and
+        mask_truncated_completions=True, the model receives no effective learning signal and therefore does not
+        update its parameters.
+
+        Here, we don't mock the generate method, but we rely on the fact that the probability of the model
+        generating the EOS token is extremely low, so all generated completions are truncated.
+        """
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=8,  # reduce the completion length to reduce memory usage
+                mask_truncated_completions=True,  # enable masking of truncated completions
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have not changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.")
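As a side note on the mocked completions above: with the tokenizer ids quoted in fake_generate (eos_token_id = 151645), only the first of the three fake sequences counts as truncated. The following standalone sketch (not part of the test file) makes that explicit:

import torch

# The three mocked completions from the test above (illustrative sketch).
completions_ids = torch.tensor(
    [
        [1, 2, 3, 4, 5, 6, 7, 8],                             # no EOS -> truncated
        [9, 10, 11, 151645, 151643, 151643, 151643, 151643],  # EOS followed by padding
        [12, 13, 14, 15, 16, 17, 18, 151645],                 # EOS on the last allowed position
    ]
)
is_eos = completions_ids == 151645   # 151645 is the eos_token_id noted in fake_generate
truncated = ~is_eos.any(dim=1)       # which completions never emitted EOS
print(truncated.tolist())            # [True, False, False]

With mask_truncated_completions=True, only that first row is dropped from the loss, so the model still updates and the assertFalse checks above pass.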
@@ -115,6 +115,10 @@ class GRPOConfig(TrainingArguments):
             applied. The [Dr. GRPO](https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf)
             paper recommends not scaling the rewards, as scaling by the standard deviation introduces a question-level
             difficulty bias.
+        mask_truncated_completions (`bool`, *optional*, defaults to `False`):
+            When enabled, truncated completions are excluded from the loss calculation, preventing them from being
+            incorrectly penalized and introducing noise during training. According to the
+            [DAPO](https://huggingface.co/papers/2503.14476) paper, this is a good practice for training stability.
         sync_ref_model (`bool`, *optional*, defaults to `False`):
             Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
             the `ref_model_mixup_alpha` parameter. This synchronization originates from the
@@ -302,6 +306,14 @@ class GRPOConfig(TrainingArguments):
             "deviation introduces a question-level difficulty bias."
         },
     )
+    mask_truncated_completions: bool = field(
+        default=False,
+        metadata={
+            "help": "When enabled, truncated completions are excluded from the loss calculation, preventing them from "
+            "being incorrectly penalized and introducing noise during training. According to the DAPO paper, this is "
+            "a good practice for training stability."
+        },
+    )
     sync_ref_model: bool = field(
         default=False,
         metadata={
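For context, a minimal user-facing sketch of the new flag (every argument value other than mask_truncated_completions is an illustrative placeholder):

from trl import GRPOConfig

# Hypothetical config enabling overlong filtering; output_dir and length are placeholders.
training_args = GRPOConfig(
    output_dir="my-grpo-run",
    max_completion_length=256,
    mask_truncated_completions=True,  # completions that never emit EOS are excluded from the loss
)

The trainer picks the flag up during initialization, as the next hunk shows.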
@@ -409,6 +409,7 @@ class GRPOTrainer(Trainer):
         self.repetition_penalty = args.repetition_penalty
         self.use_vllm = args.use_vllm
         self.use_liger_loss = args.use_liger_loss
+        self.mask_truncated_completions = args.mask_truncated_completions

         # Datasets
         if (
@@ -810,6 +811,11 @@ class GRPOTrainer(Trainer):
         sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
         completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()

+        # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
+        if self.mask_truncated_completions:
+            truncated_completions = ~is_eos.any(dim=1)
+            completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
+
         # Concatenate prompt_mask with completion_mask for logit computation
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
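To make the two added lines concrete, here is a toy walk-through with made-up tensors (eos_idx is rebuilt here analogously to how the trainer computes it a few lines above this hunk, which the diff does not show):

import torch

# Toy batch: 2 completions of length 4; only the first emits EOS (at position 1).
is_eos = torch.tensor([[False, True, False, False],
                       [False, False, False, False]])
eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long)
eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1)
completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
# completion_mask is [[1, 1, 0, 0],    tokens up to and including EOS
#                     [1, 1, 1, 1]]    truncated row: every token still counted

truncated_completions = ~is_eos.any(dim=1)  # [False, True]
completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int()
# completion_mask is now [[1, 1, 0, 0],
#                         [0, 0, 0, 0]]   the truncated completion contributes nothing to the loss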
@@ -1047,21 +1053,22 @@
         per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
         if self.beta != 0.0:
             per_token_loss = per_token_loss + self.beta * per_token_kl
-        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
+
+        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)

         # Log the metrics
         mode = "eval" if self.control.should_evaluate else "train"

         if self.beta != 0.0:
             mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
-            self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
+            self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).nanmean().item())

         # Compute the clip ratio
         is_clipped = ((coef_1 < 1 - self.epsilon_low) & (advantages.unsqueeze(1) < 0)) | (
             (coef_1 > 1 + self.epsilon_high) & (advantages.unsqueeze(1) > 0)
         )
         clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
-        self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
+        self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).nanmean().item())
         return loss

     def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
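The two behavioral tweaks in this hunk guard against the batch that mask_truncated_completions can now produce, namely one where every completion is masked: clamping the token count avoids a 0/0 loss, and nanmean keeps a NaN metric from one process from poisoning the logged average. A small standalone sketch of both effects (not repo code):

import torch

per_token_loss = torch.randn(2, 4)
completion_mask = torch.zeros(2, 4)  # every completion truncated and masked out

naive = (per_token_loss * completion_mask).sum() / completion_mask.sum()
safe = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
print(naive.item(), safe.item())  # naive is nan; the clamped version is 0.0

gathered = torch.tensor([0.12, float("nan"), 0.08])  # e.g. gathered per-process clip ratios
print(gathered.mean().item(), gathered.nanmean().item())  # mean is nan; nanmean ignores the NaN entry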