⚖️ Add option not to scale rewards (Dr. GRPO) (#3135)

This commit is contained in:
Quentin Gallouédec
2025-03-22 13:47:52 -07:00
committed by GitHub
parent 0f26049ea2
commit 9b38b0b5ee
3 changed files with 49 additions and 1 deletion

View File

@@ -914,3 +914,34 @@ class GRPOTrainerTester(unittest.TestCase):
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

    def test_training_no_scale_rewards(self):
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = GRPOConfig(
                output_dir=tmp_dir,
                learning_rate=0.1,  # increase the learning rate to speed up the test
                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
                num_generations=3,  # reduce the number of generations to reduce memory usage
                max_completion_length=32,  # reduce the completion length to reduce memory usage
                scale_rewards=False,
                report_to="none",
            )
            trainer = GRPOTrainer(
                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
                args=training_args,
                train_dataset=dataset,
            )
            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

            trainer.train()

            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

            # Check that the params have changed
            for n, param in previous_trainable_params.items():
                new_param = trainer.model.get_parameter(n)
                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
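
Outside the unittest scaffolding, enabling the new behavior is a single extra argument to GRPOConfig. A minimal usage sketch (not part of the commit) reusing the tiny models and dataset from the test above; the output_dir value is illustrative and everything else keeps the library defaults:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

# scale_rewards=False turns off the std-normalization of advantages (Dr. GRPO);
# the output directory name below is arbitrary.
training_args = GRPOConfig(output_dir="grpo-no-scale-rewards", scale_rewards=False, report_to="none")

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()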

View File

@@ -109,6 +109,12 @@ class GRPOConfig(TrainingArguments):
        reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
            Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
            weighted equally with weight `1.0`.
        scale_rewards (`bool`, *optional*, defaults to `True`):
            Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), the rewards
            are normalized by the standard deviation, ensuring they have unit variance. If `False`, no scaling is
            applied. The [Dr. GRPO](https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf)
            paper recommends not scaling the rewards, as scaling by the standard deviation introduces a question-level
            difficulty bias.
        sync_ref_model (`bool`, *optional*, defaults to `False`):
            Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
            the `ref_model_mixup_alpha` parameter. This synchronization originates from the
@@ -280,6 +286,15 @@ class GRPOConfig(TrainingArguments):
            "rewards are weighted equally with weight `1.0`."
        },
    )
    scale_rewards: bool = field(
        default=True,
        metadata={
            "help": "Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), "
            "the rewards are normalized by the standard deviation, ensuring they have unit variance. If `False`, no "
            "scaling is applied. The Dr. GRPO paper recommends not scaling the rewards, as scaling by the standard "
            "deviation introduces a question-level difficulty bias."
        },
    )
    sync_ref_model: bool = field(
        default=False,
        metadata={
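
To make the "question-level difficulty bias" in the help text concrete, here is a small illustration with made-up numbers (not part of the commit or the paper): when a prompt is nearly always solved (or nearly always failed), its per-group reward std is tiny, so dividing by it inflates the otherwise small advantages; with scale_rewards=False the advantages stay on the raw reward scale.

import torch

# Rewards for 4 sampled completions on two hypothetical prompts (illustrative numbers)
easy_prompt = torch.tensor([0.95, 1.00, 1.00, 0.95])  # nearly solved: tiny spread
hard_prompt = torch.tensor([0.10, 0.90, 0.40, 0.60])  # mixed outcomes: large spread

for name, rewards in [("easy", easy_prompt), ("hard", hard_prompt)]:
    centered = rewards - rewards.mean()          # what scale_rewards=False uses
    scaled = centered / (rewards.std() + 1e-4)   # what scale_rewards=True (default) uses
    print(f"{name}: |centered| max {centered.abs().max().item():.3f}, "
          f"|scaled| max {scaled.abs().max().item():.3f}")

# The centered advantages differ by ~16x between the prompts (0.025 vs 0.400), but after dividing
# by the per-prompt std both land on the same order of magnitude (~0.86 vs ~1.19), so the nearly
# solved prompt gets a comparable weight in the loss -- the difficulty bias Dr. GRPO points out.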

View File

@@ -827,7 +827,9 @@ class GRPOTrainer(Trainer):
         # Normalize the rewards to compute the advantages
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
-        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+        advantages = rewards - mean_grouped_rewards
+        if self.args.scale_rewards:
+            advantages = advantages / (std_grouped_rewards + 1e-4)

         # Slice to keep only the local part of the data
         process_slice = slice(
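
For readers skimming the diff, a minimal standalone sketch of the resulting advantage computation. The subtraction, the scale_rewards branch, and the repeat_interleave over num_generations mirror the hunk above; the view(-1, num_generations) grouping and the function wrapper are illustrative assumptions, since the diff does not show how the grouped mean and std are computed.

import torch

def compute_advantages(rewards: torch.Tensor, num_generations: int, scale_rewards: bool) -> torch.Tensor:
    # rewards: shape (num_prompts * num_generations,), completions for the same prompt are contiguous
    grouped = rewards.view(-1, num_generations)
    mean_grouped_rewards = grouped.mean(dim=1).repeat_interleave(num_generations, dim=0)
    std_grouped_rewards = grouped.std(dim=1).repeat_interleave(num_generations, dim=0)

    # Center within each prompt group; only divide by the group std when scaling is enabled
    advantages = rewards - mean_grouped_rewards
    if scale_rewards:
        advantages = advantages / (std_grouped_rewards + 1e-4)
    return advantages

# Two prompts, three sampled completions each
rewards = torch.tensor([1.0, 0.0, 0.0, 0.9, 1.0, 1.0])
print(compute_advantages(rewards, num_generations=3, scale_rewards=True))   # GRPO default
print(compute_advantages(rewards, num_generations=3, scale_rewards=False))  # Dr. GRPO-style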