⚖️ Add option not to scale rewards (Dr. GRPO) (#3135)
Repository: https://github.com/huggingface/trl (mirror)
parent 0f26049ea2
commit 9b38b0b5ee
@@ -914,3 +914,34 @@ class GRPOTrainerTester(unittest.TestCase):
             for n, param in previous_trainable_params.items():
                 new_param = trainer.model.get_parameter(n)
                 self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
+
+    def test_training_no_scale_rewards(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            training_args = GRPOConfig(
+                output_dir=tmp_dir,
+                learning_rate=0.1,  # increase the learning rate to speed up the test
+                per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+                num_generations=3,  # reduce the number of generations to reduce memory usage
+                max_completion_length=32,  # reduce the completion length to reduce memory usage
+                scale_rewards=False,
+                report_to="none",
+            )
+            trainer = GRPOTrainer(
+                model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+                reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+                args=training_args,
+                train_dataset=dataset,
+            )
+
+            previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+            trainer.train()
+
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+            # Check that the params have changed
+            for n, param in previous_trainable_params.items():
+                new_param = trainer.model.get_parameter(n)
+                self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")
@@ -109,6 +109,12 @@ class GRPOConfig(TrainingArguments):
         reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
             Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
             weighted equally with weight `1.0`.
+        scale_rewards (`bool`, *optional*, defaults to `True`):
+            Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), the rewards
+            are normalized by the standard deviation, ensuring they have unit variance. If `False`, no scaling is
+            applied. The [Dr. GRPO](https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf)
+            paper recommends not scaling the rewards, as scaling by the standard deviation introduces a question-level
+            difficulty bias.
         sync_ref_model (`bool`, *optional*, defaults to `False`):
             Whether to synchronize the reference model with the active model every `ref_model_sync_steps` steps, using
             the `ref_model_mixup_alpha` parameter. This synchronization originates from the
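As an aside before the next hunk: a minimal sketch (not part of the diff) of the question-level difficulty bias the docstring mentions. The reward values are invented; the centering, standard-deviation division, and 1e-4 epsilon mirror the trainer change further below.

import torch

# Two prompts with 4 completions each. The "easy" prompt gets nearly identical
# rewards (tiny std), the "hard" prompt gets spread-out rewards (large std).
rewards_easy = torch.tensor([0.90, 0.92, 0.88, 0.90])
rewards_hard = torch.tensor([0.10, 0.90, 0.30, 0.70])

for name, rewards in [("easy", rewards_easy), ("hard", rewards_hard)]:
    centered = rewards - rewards.mean()
    scaled = centered / (rewards.std() + 1e-4)  # scale_rewards=True (default GRPO)
    print(name, "unscaled:", centered.tolist(), "scaled:", scaled.tolist())

# With scaling, the low-variance "easy" group is inflated to roughly the same
# magnitude as the "hard" group; without scaling (scale_rewards=False, Dr. GRPO),
# advantage magnitudes stay proportional to the raw reward gaps.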
@@ -280,6 +286,15 @@ class GRPOConfig(TrainingArguments):
             "rewards are weighted equally with weight `1.0`."
         },
     )
+    scale_rewards: bool = field(
+        default=True,
+        metadata={
+            "help": "Whether to scale the rewards by dividing them by their standard deviation. If `True` (default), "
+            "the rewards are normalized by the standard deviation, ensuring they have unit variance. If `False`, no "
+            "scaling is applied. The Dr. GRPO paper recommends not scaling the rewards, as scaling by the standard "
+            "deviation introduces a question-level difficulty bias."
+        },
+    )
     sync_ref_model: bool = field(
         default=False,
         metadata={
@@ -827,7 +827,9 @@ class GRPOTrainer(Trainer):
         # Normalize the rewards to compute the advantages
         mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
-        advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+        advantages = rewards - mean_grouped_rewards
+        if self.args.scale_rewards:
+            advantages = advantages / (std_grouped_rewards + 1e-4)

         # Slice to keep only the local part of the data
         process_slice = slice(
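To read the changed lines above out of context, here is a self-contained sketch of the same advantage computation on a toy batch. The hunk does not show how mean_grouped_rewards and std_grouped_rewards are produced, so the per-prompt grouping below is an assumption based on the surrounding trainer code, and scale_rewards stands in for self.args.scale_rewards.

import torch

num_generations = 4
scale_rewards = False  # the new GRPOConfig flag; True reproduces the old behavior

# Flat rewards for 2 prompts x 4 completions each (toy values).
rewards = torch.tensor([0.10, 0.90, 0.30, 0.70, 0.90, 0.92, 0.88, 0.90])

# Per-prompt statistics, assumed to be taken over each prompt's completions.
grouped = rewards.view(-1, num_generations)
mean_grouped_rewards = grouped.mean(dim=1)  # shape: (num_prompts,)
std_grouped_rewards = grouped.std(dim=1)    # shape: (num_prompts,)

# Broadcast the group statistics back to one entry per completion, as in the trainer.
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(num_generations, dim=0)
std_grouped_rewards = std_grouped_rewards.repeat_interleave(num_generations, dim=0)

# New two-step computation: always center, optionally divide by the group std.
advantages = rewards - mean_grouped_rewards
if scale_rewards:
    advantages = advantages / (std_grouped_rewards + 1e-4)

print(advantages)  # shape: (num_prompts * num_generations,)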