Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00

Compare commits: 7e9c6e45d5...grpo-adv-r (9 commits)
d234d7fc02, 889dda808c, dd9a6e1d29, e481bacf49, acba243b60, c16dc0878b, 2a461eec0a, 5e925e1056, 275adebaf3
tests/test_repad.py (new file, 131 lines)
@@ -0,0 +1,131 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import deepcopy

import torch

from trl.trainer.grpo_replay_buffer import repad


PAD_TOKEN_ID = 123


def test_repad_basic_padding():
    sample = [
        {
            "prompt_ids": torch.LongTensor([1, 2, 3]),
            "prompt_mask": torch.LongTensor([1, 1, 0]),
            "completion_ids": torch.LongTensor([5, 6, 7, 8]),
            "completion_mask": torch.LongTensor([1, 1, 1, 0]),
            "old_per_token_logps": torch.tensor([0.1, 0.2, 0.3, 0.4]),
            "ref_per_token_logps": torch.tensor([0.0, -0.1, -0.2, -0.3]),
        },
        {
            "prompt_ids": torch.LongTensor([4, 5]),
            "prompt_mask": torch.LongTensor([1, 1]),
            "completion_ids": torch.LongTensor([9, 10]),
            "completion_mask": torch.LongTensor([1, 1]),
            "old_per_token_logps": torch.tensor([-0.5, -0.6]),
            "ref_per_token_logps": torch.tensor([0.5, 0.6]),
        },
    ]

    padded = repad(deepcopy(sample), padding_value=PAD_TOKEN_ID)

    assert len(padded[0]["prompt_ids"]) == 2
    assert len(padded[0]["completion_ids"]) == 3

    for ex in padded:
        # All sequences in same batch should have same length
        assert len(ex["prompt_ids"]) == len(padded[0]["prompt_ids"])
        assert len(ex["prompt_mask"]) == len(padded[0]["prompt_mask"])
        assert len(ex["completion_ids"]) == len(padded[0]["completion_ids"])
        assert len(ex["completion_mask"]) == len(padded[0]["completion_mask"])

        # Mask and ids should match in shape
        assert ex["prompt_ids"].shape == ex["prompt_mask"].shape
        assert ex["completion_ids"].shape == ex["completion_mask"].shape


def test_repad_logps_padding():
    sample = [
        {
            "prompt_ids": torch.LongTensor([1]),
            "prompt_mask": torch.LongTensor([1]),
            "completion_ids": torch.LongTensor([2, 3, 4]),
            "completion_mask": torch.LongTensor([1, 1, 0]),
            "old_per_token_logps": torch.tensor([-0.1, -0.2, -0.3]),
            "ref_per_token_logps": torch.tensor([-0.5, -0.6, -0.7]),
        },
        {
            "prompt_ids": torch.LongTensor([5, 6]),
            "prompt_mask": torch.LongTensor([1, 1]),
            "completion_ids": torch.LongTensor([7, 8]),
            "completion_mask": torch.LongTensor([1, 1]),
            "old_per_token_logps": torch.tensor([0.4, 0.5]),
            "ref_per_token_logps": torch.tensor([0.6, 0.7]),
        },
    ]

    padded = repad(deepcopy(sample), padding_value=PAD_TOKEN_ID)

    for logps in ["old_per_token_logps", "ref_per_token_logps"]:
        for ex in padded:
            assert len(ex[logps]) == len(padded[0][logps])
            assert isinstance(ex[logps], torch.Tensor)


def test_repad_empty_masks():
    sample = [
        {
            "prompt_ids": torch.tensor([0]),
            "prompt_mask": torch.tensor([0]),
            "completion_ids": torch.tensor([0]),
            "completion_mask": torch.tensor([0]),
            "old_per_token_logps": torch.tensor([0.0]),
            "ref_per_token_logps": torch.tensor([0.0]),
        },
        {
            "prompt_ids": torch.tensor([1]),
            "prompt_mask": torch.tensor([0]),
            "completion_ids": torch.tensor([1]),
            "completion_mask": torch.tensor([0]),
            "old_per_token_logps": torch.tensor([0.0]),
            "ref_per_token_logps": torch.tensor([0.0]),
        },
        {
            "prompt_ids": torch.tensor([1, 1]),
            "prompt_mask": torch.tensor([0, 1]),
            "completion_ids": torch.tensor([1, 2]),
            "completion_mask": torch.tensor([1, 0]),
            "old_per_token_logps": torch.tensor([0.0, 1.0]),
            "ref_per_token_logps": torch.tensor([0.0, 1.0]),
        },
        {
            "prompt_ids": torch.tensor([1, 1]),
            "prompt_mask": torch.tensor([1, 1]),
            "completion_ids": torch.tensor([1, 2]),
            "completion_mask": torch.tensor([1, 0]),
            "old_per_token_logps": torch.tensor([0.0, 1.0]),
            "ref_per_token_logps": torch.tensor([0.0, 1.0]),
        },
    ]
    padded = repad(deepcopy(sample), padding_value=999)

    assert len(padded[0]["prompt_ids"]) == 2
    assert len(padded[0]["completion_ids"]) == 1

    assert padded[0]["prompt_ids"].eq(999).all()
    assert padded[0]["completion_ids"].eq(999).all()
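To exercise these tests locally, a standard pytest invocation should work (assuming the repository is installed, e.g. with `pip install -e .`):

pytest tests/test_repad.py -q

Note that `repad` left-pads prompts and right-pads completions (see `padding_side` in `grpo_replay_buffer.py` below). Because example 0 in `test_repad_empty_masks` has all-zero masks, every surviving position in it is padding, which is what the final two asserts check.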
train_grpo.py (new file, 35 lines)
@@ -0,0 +1,35 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer


dataset = load_dataset("trl-lib/tldr", split="train")


# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=1, replay_buffer_class="SSRReplayBuffer")
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
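The script above relies on the SSR defaults. For reference, the SSR-specific options introduced by this change can also be set explicitly; a minimal sketch (the values are illustrative, not tuned recommendations):

training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    logging_steps=1,
    replay_buffer_class="SSRReplayBuffer",
    ssr_capacity_scalar=4,  # buffer capacity = 4 x the effective batch size
    ssr_alpha=1.0,          # exponent applied to |advantage| when computing sampling priorities
)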
trl/trainer/grpo_config.py
@@ -153,6 +153,8 @@ class GRPOConfig(TrainingArguments):
        use_liger_loss (`bool`, *optional*, defaults to `False`):
            Whether to use the Liger GRPO loss.
        replay_buffer_class (`str`, *optional*, defaults to `"ReplayBuffer"`):
            Replay buffer class used to store and sample rollouts. Options are `"ReplayBuffer"` (random sampling
            without replacement) and `"SSRReplayBuffer"` (advantage-prioritized sampling).

        > Parameters that control the logging

        log_completions (`bool`, *optional*, defaults to `False`):
@@ -393,6 +395,26 @@ class GRPOConfig(TrainingArguments):
        metadata={"help": "Whether to use the Liger GRPO loss."},
    )
    replay_buffer_class: str = field(
        default="ReplayBuffer",
        metadata={
            "help": "Replay buffer class to use. Options: `ReplayBuffer` or `SSRReplayBuffer`. The default is "
            "`ReplayBuffer`, which samples randomly without replacement."
        },
    )
    ssr_capacity_scalar: int = field(
        default=4,
        metadata={
            "help": "Scalar multiplied by the number of training samples in the effective batch to obtain the "
            "replay buffer capacity. The default is 4."
        },
    )
    ssr_alpha: float = field(
        default=1.0,
        metadata={
            "help": "Alpha parameter controlling the probability distribution used when sampling from the replay "
            "buffer. The default is 1.0."
        },
    )

    # Parameters that control the logging
    log_completions: bool = field(
        default=False,
trl/trainer/grpo_replay_buffer.py (new file, 154 lines)
@@ -0,0 +1,154 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random

import numpy as np

from .utils import pad


def repad(list_of_tensor_dicts, padding_value):
    p_ids, p_attn_masks = remove_and_pad(
        [tensor_dict["prompt_ids"] for tensor_dict in list_of_tensor_dicts],
        [tensor_dict["prompt_mask"] for tensor_dict in list_of_tensor_dicts],
        pad_token_id=padding_value,
        padding_side="left",
    )
    c_ids, c_attn_masks = remove_and_pad(
        [tensor_dict["completion_ids"] for tensor_dict in list_of_tensor_dicts],
        [tensor_dict["completion_mask"] for tensor_dict in list_of_tensor_dicts],
        pad_token_id=padding_value,
    )
    old_logps, _ = remove_and_pad(
        [tensor_dict["old_per_token_logps"] for tensor_dict in list_of_tensor_dicts],
        [tensor_dict["completion_mask"] for tensor_dict in list_of_tensor_dicts],
        pad_token_id=-10000.0,  # ignored so can be anything
    )
    ref_logps, _ = remove_and_pad(
        [tensor_dict["ref_per_token_logps"] for tensor_dict in list_of_tensor_dicts],
        [tensor_dict["completion_mask"] for tensor_dict in list_of_tensor_dicts],
        pad_token_id=-10000.0,  # ignored so can be anything
    )

    for i, (p_id, p_mask, c_id, c_mask, o_logp, r_logp) in enumerate(
        zip(p_ids, p_attn_masks, c_ids, c_attn_masks, old_logps, ref_logps)
    ):
        list_of_tensor_dicts[i]["prompt_ids"] = p_id
        list_of_tensor_dicts[i]["prompt_mask"] = p_mask
        list_of_tensor_dicts[i]["completion_ids"] = c_id
        list_of_tensor_dicts[i]["completion_mask"] = c_mask
        list_of_tensor_dicts[i]["old_per_token_logps"] = o_logp
        list_of_tensor_dicts[i]["ref_per_token_logps"] = r_logp

    return list_of_tensor_dicts


def remove_and_pad(list_of_ids, list_of_masks, pad_token_id=0, padding_side="right"):
    """
    Remove padding from list_of_ids and list_of_masks, and then pad them to the same length.
    """
    num_samples = len(list_of_ids)
    if list_of_ids[0] is None:
        # we are not using old_per_token_logps / ref_per_token_logps
        return [None] * num_samples, [None] * num_samples
    # Remove padding
    list_of_ids = [ids[mask == 1] for ids, mask in zip(list_of_ids, list_of_masks)]
    list_of_masks = [mask[mask == 1] for mask in list_of_masks]

    ids = pad(list_of_ids, padding_value=pad_token_id, padding_side=padding_side)
    masks = pad(list_of_masks, padding_value=0, padding_side=padding_side)

    return ids, masks


def remove_padding(input_ids, attn_mask):
    """
    Remove padding from input_ids and attn_mask.
    """
    if attn_mask is not None:
        input_ids = input_ids[attn_mask == 1]
        attn_mask = attn_mask[attn_mask == 1]
    return input_ids, attn_mask


class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.sample_indices = []

    def add(self, experience):
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer.pop(0)
            self.buffer.append(experience)

        # Clear index queue when buffer changes
        self.sample_indices.clear()

    def _init_sampling_queue(self):
        self.sample_indices = list(range(len(self.buffer)))
        random.shuffle(self.sample_indices)

    def sample(self, batch_size):
        if not self.sample_indices:
            self._init_sampling_queue()

        batch = []
        while len(batch) < batch_size and self.sample_indices:
            idx = self.sample_indices.pop(0)
            batch.append(self.buffer[idx])

        if len(batch) != batch_size:
            raise ValueError("Not enough samples in the buffer to fill the batch.")

        return batch

    def __len__(self):
        return len(self.buffer)


class SSRReplayBuffer(ReplayBuffer):
    # Implementation of the SSR replay buffer from https://arxiv.org/pdf/2504.08837
    def __init__(self, capacity, alpha=1.0):
        super().__init__(capacity)
        self.alpha = alpha
        self.advantages = []

    def add(self, experience):
        EPS = 0.0001  # ensures we get non-zero advs when the buffer contains all 0 advantages
        advantage = experience["advantages"].item()
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
            self.advantages.append(abs(advantage) + EPS)  # Store absolute advantage
        else:
            # Replace the oldest entry if the buffer is full
            self.buffer.pop(0)
            self.advantages.pop(0)
            self.buffer.append(experience)
            self.advantages.append(abs(advantage) + EPS)  # keep EPS here too so priorities stay strictly positive

    def sample(self, batch_size):
        if not self.buffer:
            raise ValueError("Buffer is empty. Cannot sample from an empty buffer.")

        # Convert advantages to priorities
        scaled_priorities = np.power(self.advantages, self.alpha)
        total_priority = np.sum(scaled_priorities)
        probabilities = scaled_priorities / total_priority

        indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
        return [self.buffer[i] for i in indices]
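To make the effect of `alpha` concrete, here is a small illustrative snippet (not part of the PR) that reproduces the priority computation from `SSRReplayBuffer.sample` on a toy set of stored absolute advantages:

import numpy as np

advantages = [0.1, 0.4, 1.5]  # toy |advantage| values stored in the buffer
for alpha in (0.5, 1.0, 2.0):
    priorities = np.power(advantages, alpha)
    probabilities = priorities / priorities.sum()
    print(alpha, probabilities.round(3))
# Larger alpha concentrates sampling on high-|advantage| experiences;
# alpha = 0 would make the distribution uniform.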
trl/trainer/grpo_trainer.py
@@ -51,6 +51,7 @@ from ..import_utils import is_liger_kernel_available, is_rich_available, is_vllm
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
from .grpo_config import GRPOConfig
from .grpo_replay_buffer import ReplayBuffer, SSRReplayBuffer, repad
from .utils import (
    disable_dropout_in_model,
    generate_model_card,
@@ -219,6 +220,31 @@ def split_tensor_dict(
    ]


def combine_tensor_dict(split_dicts: list[dict[str, Optional[torch.Tensor]]]) -> dict[str, Optional[torch.Tensor]]:
    """
    Combines a list of dictionaries containing tensors into a single dictionary by stacking the tensors along a new
    first dimension.

    Example:
    >>> d1 = {"x": torch.tensor([0, 1]), "y": torch.tensor([0])}
    >>> d2 = {"x": torch.tensor([2, 3]), "y": torch.tensor([1])}
    >>> d3 = {"x": torch.tensor([4, 5]), "y": torch.tensor([2])}
    >>> combine_tensor_dict([d1, d2, d3])
    {
        "x": tensor([[0, 1], [2, 3], [4, 5]]),
        "y": tensor([[0], [1], [2]])
    }
    """
    combined_dict = {}
    keys = split_dicts[0].keys()

    for key in keys:
        tensors = [d[key] for d in split_dicts if d[key] is not None]
        combined_dict[key] = torch.stack(tensors, dim=0) if tensors else None

    return combined_dict


def nanmin(tensor: torch.Tensor) -> torch.Tensor:
    """
    Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
@@ -673,6 +699,22 @@ class GRPOTrainer(Trainer):
                else:
                    self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

        # Instantiate the replay buffer; the standard setting uses the plain ReplayBuffer
        effective_batch_size = self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps

        if self.args.replay_buffer_class == "ReplayBuffer":
            self.replay_buffer = ReplayBuffer(capacity=effective_batch_size)
        elif self.args.replay_buffer_class == "SSRReplayBuffer":
            self.replay_buffer = SSRReplayBuffer(
                capacity=effective_batch_size * self.args.ssr_capacity_scalar,
                alpha=self.args.ssr_alpha,
            )
        else:
            raise ValueError(
                f"Invalid `replay_buffer_class` passed to `GRPOConfig`. Expected either 'ReplayBuffer' or 'SSRReplayBuffer', but got {self.args.replay_buffer_class}."
            )

    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.
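As a quick check of the capacity arithmetic above, with hypothetical values chosen only to illustrate the formula:

per_device_train_batch_size = 4
gradient_accumulation_steps = 8
ssr_capacity_scalar = 4

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 32
ssr_capacity = effective_batch_size * ssr_capacity_scalar  # 128 experiences kept in the SSR buffer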
@@ -887,13 +929,19 @@ class GRPOTrainer(Trainer):
         mode = "train" if self.model.training else "eval"
         if mode == "train":
             generate_every = self.args.gradient_accumulation_steps * self.num_iterations
-            if self._step % generate_every == 0 or self._buffered_inputs is None:
+            if self._step % generate_every == 0 or len(self.replay_buffer) == 0:
                 # self._buffered_inputs=None can occur when resuming from a checkpoint
                 accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
-                self._buffered_inputs = split_tensor_dict(
-                    accumulated_local_batch, self.args.gradient_accumulation_steps
-                )
-            inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
+                effective_batch_size = self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps
+                split_tensors = split_tensor_dict(accumulated_local_batch, effective_batch_size)
+
+                for tensor in split_tensors:
+                    self.replay_buffer.add(tensor)
+
+            split_inputs = self.replay_buffer.sample(self.args.per_device_train_batch_size)
+            repadded_split_inputs = repad(split_inputs, padding_value=self.processing_class.pad_token_id)
+            inputs = combine_tensor_dict(repadded_split_inputs)
+
             self._step += 1
         else:
             # In evaluation, there is neither gradient accumulation, nor multiple iterations