Compare commits

...

16 Commits

Author SHA1 Message Date
b5f8feb4a6 Merge branch 'main' into grpo-ssr-replay-buffer 2025-05-28 07:29:16 +00:00
6b485118a7 add a DAPO replay buffer and only store new exp when abs(adv) is nonzero 2025-05-28 07:28:36 +00:00
f88569edd9 fix off by 1 on index 2025-05-27 09:23:44 +00:00
e7986290ba fix replay buffer 2025-05-27 07:34:33 +00:00
b562ec28af Merge branch 'grpo-per-batch-padding' into grpo-ssr-replay-buffer 2025-05-26 19:03:29 +00:00
631499efe8 fix bug where old_log_probs=None 2025-05-26 19:01:47 +00:00
446fd001ab save wip 2025-05-26 18:51:04 +00:00
727b3767d1 save wip 2025-05-26 18:49:46 +00:00
edbf12a061 precommit 2025-05-26 13:57:52 +00:00
f9b645d292 adds replay buffer 2025-05-26 13:54:59 +00:00
91ad02766f precommit 2025-05-26 11:37:46 +00:00
b2b285872e nits 2025-05-26 11:37:19 +00:00
7ce8661b84 adds adding for loss minibatches 2025-05-26 08:51:38 +00:00
7a17bc3fbc fix bug with attn mask pad remove 2025-05-23 13:56:03 +00:00
29007cc546 save wip 2025-05-21 15:05:57 +00:00
e7ceb78c93 adds option to return mask to pad, refactor padding 2025-05-21 14:07:20 +00:00
4 changed files with 330 additions and 16 deletions

View File

@ -24,7 +24,7 @@ from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available
from trl import GRPOConfig, GRPOTrainer
from trl.trainer.grpo_trainer import RepeatSampler, shuffle_tensor_dict, split_tensor_dict
from trl.trainer.grpo_trainer import RepeatSampler, shuffle_dict_list, shuffle_tensor_dict, split_tensor_dict
from .testing_utils import require_vllm
@ -104,6 +104,40 @@ class ShuffleTensorDictTester(unittest.TestCase):
self.assertEqual(shuffled["x"].shape, x.shape)
class ShuffleDictListTester(unittest.TestCase):
def test_shuffle_preserves_length(self):
a = [1, 2, 3, 4]
b = ["a", "b", "c", "d"]
tensor_dict = {"a": a.copy(), "b": b.copy()}
shuffled = shuffle_dict_list(tensor_dict)
self.assertEqual(len(shuffled["a"]), len(a))
self.assertEqual(len(shuffled["b"]), len(b))
def test_shuffle_consistent_across_lists(self):
a = [10, 20, 30]
b = ["x", "y", "z"]
tensor_dict = {"a": a.copy(), "b": b.copy()}
shuffled = shuffle_dict_list(tensor_dict)
mapping = dict(zip(shuffled["a"], shuffled["b"]))
self.assertEqual(mapping[10], "x")
self.assertEqual(mapping[20], "y")
self.assertEqual(mapping[30], "z")
def test_none_list_remains_none(self):
a = [1, 2, 3]
tensor_dict = {"a": a.copy(), "b": None}
shuffled = shuffle_dict_list(tensor_dict)
self.assertIsNone(shuffled["b"])
self.assertEqual(len(shuffled["a"]), len(a))
class RepeatRandomSamplerTester(unittest.TestCase):
def test_sampler(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]

View File

@ -145,6 +145,15 @@ class GRPOConfig(TrainingArguments):
epsilon_high (`float` or `None`, *optional*, defaults to `None`):
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
replay_buffer_class (`str`, *optional*, defaults to `"ReplayBuffer"`):
Replay buffer class to use. Options are [`ReplayBuffer`], [`SSRReplayBuffer`] and [`DapoReplayBuffer`]. The default
is `"ReplayBuffer"`, which randomly samples without replacement.
ssr_capacity_scalar (`int`, *optional*, defaults to `4`):
Scalar by which the replay buffer capacity is multiplied. A value of `1` means the capacity is equal to the
number of training samples in the effective batch.
ssr_alpha (`float`, *optional*, defaults to `1.0`):
Alpha parameter controlling the sampling distribution of the replay buffer. Stored priorities are raised to the
power `alpha` before being normalized into sampling probabilities; `0.0` recovers uniform sampling.
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
weighted equally with weight `1.0`.
@ -417,6 +426,26 @@ class GRPOConfig(TrainingArguments):
"lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
},
)
replay_buffer_class: str = field(
default="ReplayBuffer",
metadata={
"help": "Replay buffer class to use, Options [ReplayBuffer, SSRReplayBuffer] The default is `ReplayBuffer`, that randomly samples without replacement."
},
)
ssr_capacity_scalar: int = field(
default=4,
metadata={
"help": "Scalar to multiply the replay buffer capacity. The default is 1, which means the capacity is "
"equal to the number of training samples in the effective batch."
},
)
ssr_alpha: float = field(
default=1.0,
metadata={
"help": "Alpha parameter for controlling the probablity distribution of the replay buffer. The default is 1.0, "
},
)
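# Usage sketch (illustrative, not part of the diff): enabling the SSR buffer from the training
# config, assuming this branch of TRL is installed. Argument values are placeholders.
#
#     from trl import GRPOConfig
#
#     config = GRPOConfig(
#         output_dir="grpo-ssr-demo",            # placeholder path
#         replay_buffer_class="SSRReplayBuffer",
#         ssr_capacity_scalar=4,                 # capacity = 4 * generation_batch_size
#         ssr_alpha=1.0,                         # priorities ~ |advantage| ** alpha
#     )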
reward_weights: Optional[list[float]] = field(
default=None,
metadata={

View File

@ -13,6 +13,7 @@
# limitations under the License.
import os
import random
import textwrap
import warnings
from collections import defaultdict, deque
@ -21,6 +22,7 @@ from contextlib import nullcontext
from typing import Any, Callable, Optional, Union
import datasets
import numpy as np
import torch
import torch.utils.data
import transformers
@ -81,6 +83,124 @@ if is_wandb_available():
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
class ReplayBuffer:
def __init__(self, capacity):
self.capacity = capacity
self.buffer = []
self.sample_indices = []
def add(self, experience):
if len(self.buffer) < self.capacity:
self.buffer.append(experience)
else:
self.buffer.pop(0)
self.buffer.append(experience)
# Clear index queue when buffer changes
self.sample_indices.clear()
def add_batch(self, experiences: dict[str, list[torch.Tensor]]):
"""
Add a batch of experiences to the replay buffer.
"""
first_tensor = next(tensor for tensor in experiences.values() if tensor is not None)
num_items = len(first_tensor)
for i in range(num_items):
experience = {key: tensor[i] if tensor is not None else None for key, tensor in experiences.items()}
self.add(experience)
def _init_sampling_queue(self):
self.sample_indices = list(range(len(self.buffer)))
random.shuffle(self.sample_indices)
def sample(self, batch_size):
if not self.sample_indices:
self._init_sampling_queue()
batch = []
while len(batch) < batch_size and self.sample_indices:
idx = self.sample_indices.pop(0)
batch.append(self.buffer[idx])
if len(batch) != batch_size:
raise ValueError("Not enough samples in the buffer to fill the batch.")
return {k: [d[k] for d in batch] if batch[0][k] is not None else None for k in batch[0]}
def __len__(self):
return len(self.buffer)
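# Usage sketch (illustrative, not part of the diff): how a batch of experiences flows through
# the base ReplayBuffer. Keys and values below are made-up placeholders.
#
#     buf = ReplayBuffer(capacity=4)
#     buf.add_batch({
#         "prompt_ids": [torch.tensor([1, 2]), torch.tensor([3, 4])],
#         "advantages": [torch.tensor(0.5), torch.tensor(-0.5)],
#         "old_per_token_logps": None,   # all-None entries stay None in the sampled batch
#     })
#     batch = buf.sample(batch_size=2)
#     # batch["prompt_ids"] -> list of 2 tensors (shuffled order, no replacement)
#     # batch["old_per_token_logps"] -> None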
class SSRReplayBuffer(ReplayBuffer):
# implementation of the SSR replay buffer from https://arxiv.org/pdf/2504.08837
def __init__(self, capacity, alpha=1.0):
super().__init__(capacity)
self.alpha = alpha
self.advantages = []
def add(self, experience):
EPS = 0.0001 # ensures non-zero sampling priorities even when the buffer contains only zero advantages
advantage = experience["advantages"].item()
if len(self.buffer) < self.capacity:
self.buffer.append(experience)
self.advantages.append(abs(advantage) + EPS) # Store absolute advantage
elif abs(advantage) > EPS:
# Replace the oldest entry if the buffer is full and adv is non zero
self.buffer.pop(0)
self.advantages.pop(0)
self.buffer.append(experience)
self.advantages.append(abs(advantage))
def sample(self, batch_size):
if not self.buffer:
raise ValueError("Buffer is empty. Cannot sample from an empty buffer.")
# Convert advantages to priorities
scaled_priorities = np.power(self.advantages, self.alpha)
total_priority = np.sum(scaled_priorities)
probabilities = scaled_priorities / total_priority
indices = np.random.choice(len(self.buffer), batch_size, p=probabilities, replace=True)
batch = [self.buffer[i] for i in indices]
return {k: [d[k] for d in batch] if batch[0][k] is not None else None for k in batch[0]}
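# Worked example (illustrative) of how SSRReplayBuffer turns stored |advantages| into
# sampling probabilities:
#
#     advantages = [0.1, 0.4, 0.5]
#     priorities = np.power(advantages, 1.0)         # alpha = 1.0 -> [0.1, 0.4, 0.5]
#     probabilities = priorities / priorities.sum()  # [0.1, 0.4, 0.5]
#     # alpha = 0.0 would give uniform probabilities [1/3, 1/3, 1/3]; larger alpha
#     # concentrates sampling on high-|advantage| experiences.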
class DapoReplayBuffer(ReplayBuffer):
# DAPO-style replay buffer (https://huggingface.co/papers/2503.14476): samples uniformly without
# replacement and, once full, only admits new experiences whose |advantage| is nonzero
def __init__(self, capacity, alpha=1.0):
super().__init__(capacity)
self.alpha = alpha
self.weights = []
def add(self, experience):
EPS = 0.0001 # threshold below which an advantage is treated as zero
advantage = experience["advantages"].item()
if len(self.buffer) < self.capacity:
self.buffer.append(experience)
self.weights.append(1.0) # uniform weight; sampling is uniform over stored experiences
elif abs(advantage) > EPS:
# Replace the oldest entry only if the buffer is full and the advantage is nonzero
self.buffer.pop(0)
self.weights.pop(0)
self.buffer.append(experience)
self.weights.append(1.0)
def sample(self, batch_size):
if not self.buffer:
raise ValueError("Buffer is empty. Cannot sample from an empty buffer.")
# Convert the (uniform) weights to sampling probabilities
scaled_priorities = np.power(self.weights, self.alpha)
total_priority = np.sum(scaled_priorities)
probabilities = scaled_priorities / total_priority
indices = np.random.choice(len(self.buffer), batch_size, p=probabilities, replace=False)
batch = [self.buffer[i] for i in indices]
return {k: [d[k] for d in batch] if batch[0][k] is not None else None for k in batch[0]}
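# Design note (illustrative): because every stored weight is 1.0, np.power(weights, alpha) is
# all-ones for any alpha, so the np.random.choice call above reduces to uniform sampling without
# replacement; the DAPO-style filtering comes entirely from `add`, which only overwrites old
# entries with new experiences whose |advantage| exceeds EPS once the buffer is full.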
class RepeatSampler(Sampler):
"""
Sampler that repeats the indices of a dataset in a structured manner.
@ -215,7 +335,7 @@ def split_tensor_dict(
]
"""
first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
chunk_size = first_tensor.shape[0] // num_chunks
chunk_size = len(first_tensor) // num_chunks
return [
{
key: tensor[i * chunk_size : (i + 1) * chunk_size] if tensor is not None else None
@ -225,6 +345,30 @@ def split_tensor_dict(
]
def combine_tensor_dict(split_dicts: list[dict[str, Optional[torch.Tensor]]]) -> dict[str, Optional[torch.Tensor]]:
"""
Combines a list of dictionaries containing tensors into a single dictionary by
concatenating the tensors along the first dimension.
Example:
>>> d1 = {"x": torch.tensor([[0, 1], [2, 3]]), "y": torch.tensor([[0], [1]])}
>>> d2 = {"x": torch.tensor([[4, 5], [6, 7]]), "y": torch.tensor([[2], [3]])}
>>> d3 = {"x": torch.tensor([[8, 9], [10, 11]]), "y": torch.tensor([[4], [5]])}
>>> combine_tensor_dict([d1, d2, d3])
{
"x": tensor([[ 0, 1], [ 2, 3], [ 4, 5], [ 6, 7], [ 8, 9], [10, 11]]),
"y": tensor([[0], [1], [2], [3], [4], [5]])
}
"""
combined_dict = {}
keys = split_dicts[0].keys()
for key in keys:
tensors = [d[key] for d in split_dicts if d[key] is not None]
# Concatenate along the first dimension, as in the docstring example above
combined_dict[key] = torch.cat(tensors, dim=0) if tensors else None
return combined_dict
def shuffle_tensor_dict(tensor_dict: dict[str, Optional[torch.Tensor]]) -> dict[str, Optional[torch.Tensor]]:
"""
Shuffles a dictionary of tensors along the first dimension in unison.
@ -247,6 +391,25 @@ def shuffle_tensor_dict(tensor_dict: dict[str, Optional[torch.Tensor]]) -> dict[
return {key: tensor[permutation] if tensor is not None else None for key, tensor in tensor_dict.items()}
def shuffle_dict_list(tensor_dict: dict[str, Optional[list[Any]]]) -> dict[str, Optional[list[Any]]]:
"""
Shuffles a dictionary of lists along the first dimension in unison.
Example:
>>> x = [1, 2, 3]
>>> y = [4, 5, 6]
>>> tensor_dict = {"x": x, "y": y}
>>> shuffle_dict_list(tensor_dict)
{'x': [2, 1, 3], 'y': [5, 4, 6]}
"""
first_tensor = next(tensor for tensor in tensor_dict.values() if tensor is not None)
batch_size = len(first_tensor)
permutation = torch.randperm(batch_size)
return {
key: [tensor[i] for i in permutation] if tensor is not None else None for key, tensor in tensor_dict.items()
}
def nanmin(tensor: torch.Tensor) -> torch.Tensor:
"""
Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
@ -528,7 +691,22 @@ class GRPOTrainer(Trainer):
self._step = 0
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
# `_get_train_sampler` and `_prepare_inputs`.
self._buffered_inputs = None
if args.replay_buffer_class == "ReplayBuffer":
self.replay_buffer = ReplayBuffer(capacity=args.generation_batch_size)
elif args.replay_buffer_class == "SSRReplayBuffer":
self.replay_buffer = SSRReplayBuffer(
capacity=args.generation_batch_size * args.ssr_capacity_scalar,
alpha=args.ssr_alpha,
)
elif args.replay_buffer_class == "DapoReplayBuffer":
self.replay_buffer = DapoReplayBuffer(
capacity=args.generation_batch_size * args.ssr_capacity_scalar,
alpha=args.ssr_alpha,
)
else:
raise ValueError(
f"Invalid `replay_buffer_class` passed to `GRPOConfig`. Expected either 'ReplayBuffer' or 'SSRReplayBuffer', but got {self.args.replay_buffer_class}."
)
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
@ -967,12 +1145,16 @@ class GRPOTrainer(Trainer):
mode = "train" if self.model.training else "eval"
if mode == "train":
generate_every = self.args.steps_per_generation * self.num_iterations
if self._step % generate_every == 0 or self._buffered_inputs is None:
if self._step % generate_every == 0 or len(self.replay_buffer) == 0:
# an empty replay buffer can occur when resuming from a checkpoint
generation_batch = self._generate_and_score_completions(generation_batch)
generation_batch = shuffle_tensor_dict(generation_batch)
self._buffered_inputs = split_tensor_dict(generation_batch, self.args.steps_per_generation)
inputs = self._buffered_inputs[self._step % self.args.steps_per_generation]
# We shuffle the generation batch so that the order of prompts is randomized across
# different steps; only relevant if the generation batch is larger than the optimization batch.
generation_batch_shuffled = shuffle_dict_list(generation_batch)
self.replay_buffer.add_batch(generation_batch_shuffled)
inputs = self.replay_buffer.sample(self.args.per_device_train_batch_size)
self._step += 1
else:
# In evaluation, there is neither batch grouping for generation, nor multiple iterations, hence
@ -1254,13 +1436,45 @@ class GRPOTrainer(Trainer):
self._textual_logs["rewards"][name].extend(rewards_per_func[:, i].tolist())
self._textual_logs["advantages"].extend(all_process_advantages.tolist())
# Unpad the prompt and completion ids to optimise memory: https://github.com/huggingface/trl/pull/3495
unpadded_prompt_ids = []
unpadded_completion_ids = []
unpadded_per_token_logps = []
prompt_ids_length = prompt_ids.size(1)
logp_iter = attention_mask if old_per_token_logps is None else old_per_token_logps # placeholder iterable so the zip below works when old_per_token_logps is None
# get the start and end indices of the attention_mask
for p_ids, c_ids, old_logps, mask in zip(prompt_ids, completion_ids, logp_iter, attention_mask):
indices = torch.where(mask == 1)[0]
if len(indices) > 0:
start = indices[0]
end = indices[-1] + 1
if prompt_ids_length > end:
raise ValueError(
f"End index {end} exceeds prompt_ids_length {prompt_ids_length}. "
"This can happen if the attention mask is not correctly set."
)
# prompt ids were left padded
unpadded_prompt_ids.append(p_ids[start:prompt_ids_length])
# completion ids were right padded
unpadded_completion_ids.append(c_ids[: end - prompt_ids_length])
if mask[start:end].sum() != end - start:
raise ValueError(f"Attention mask from {start} to {end} does not match the expected length. ")
if old_per_token_logps is not None:
unpadded_per_token_logps.append(old_logps[: end - prompt_ids_length])
else:
# case where the attention mask is all zeros, e.g. when mask_truncated_completions is enabled
raise ValueError(
"Attention mask is all zeros"
f" for prompt {p_ids.tolist()} and completion {c_ids.tolist()}. "
)
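# Worked example (illustrative) of the unpadding above, with made-up token ids and 0 as pad:
#
#     prompt_ids[i]     = [0, 0, 5, 6]                 # left-padded, prompt_ids_length = 4
#     completion_ids[i] = [7, 8, 0, 0]                 # right-padded
#     attention_mask[i] = [0, 0, 1, 1, 1, 1, 0, 0]     # over [prompt | completion]
#     start = 2, end = 6
#     unpadded_prompt_ids[i]     -> [5, 6]             # p_ids[start:prompt_ids_length]
#     unpadded_completion_ids[i] -> [7, 8]             # c_ids[: end - prompt_ids_length]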
return {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"unpadded_prompt_ids": unpadded_prompt_ids,
"unpadded_completion_ids": unpadded_completion_ids,
"unpadded_per_token_logps": unpadded_per_token_logps if old_per_token_logps is not None else None,
"advantages": advantages,
"old_per_token_logps": old_per_token_logps,
}
def compute_liger_loss(self, unwrapped_model, inputs):
@ -1312,14 +1526,42 @@ class GRPOTrainer(Trainer):
@profiling_decorator
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
prompt_ids, prompt_mask = pad(
inputs["unpadded_prompt_ids"],
padding_value=self.processing_class.pad_token_id,
padding_side="left",
return_mask=True,
)
completion_ids, completion_mask = pad(
inputs["unpadded_completion_ids"],
padding_value=self.processing_class.pad_token_id,
padding_side="right",
return_mask=True,
)
old_per_token_logps = inputs["unpadded_per_token_logps"]
if old_per_token_logps is not None:
old_per_token_logps = pad(inputs["unpadded_per_token_logps"], padding_value=0.0, padding_side="right")
padded_inputs = {
"prompt_ids": prompt_ids,
"prompt_mask": prompt_mask,
"completion_ids": completion_ids,
"completion_mask": completion_mask,
"advantages": torch.stack(inputs["advantages"]),
"old_per_token_logps": old_per_token_logps,
}
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
if self.use_liger_loss:
# Compute the loss using the liger grpo loss
unwrapped_model = self.accelerator.unwrap_model(model)
return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
return self._forward_redirection(
model, unwrapped_model, self.compute_liger_loss, unwrapped_model, padded_inputs
)
else:
return self._compute_loss(model, inputs)
return self._compute_loss(model, padded_inputs)
def _compute_loss(self, model, inputs):
# Compute the per-token log probabilities for the model

View File

@ -420,7 +420,8 @@ def pad(
padding_value: int = 0,
padding_side: str = "right",
pad_to_multiple_of: Optional[int] = None,
) -> torch.Tensor:
return_mask: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Pads a list of tensors to the same shape along the first dimension.
@ -433,6 +434,8 @@ def pad(
Side on which to add padding. Must be 'left' or 'right'. Default is 'right'.
pad_to_multiple_of (`int`, *optional*, defaults to `None`):
If set will pad the sequence to a multiple of the provided value.
return_mask (`bool`, *optional*, defaults to `False`):
If `True`, also returns an attention mask tensor of the same shape as the output, with 1s at non-padded positions.
Returns:
`torch.Tensor`:
@ -461,7 +464,9 @@ def pad(
# Create an output tensor filled with the padding value
output = torch.full((len(tensors), *output_shape), padding_value, dtype=tensors[0].dtype, device=tensors[0].device)
# Initialize mask tensor if required
if return_mask:
mask = torch.zeros((len(tensors), output_shape[0]), dtype=torch.long, device=tensors[0].device)
for i, t in enumerate(tensors):
if padding_side == "left":
seq_start = output_shape[0] - t.shape[0]
@ -474,6 +479,10 @@ def pad(
seq_slice = slice(seq_start, seq_start + t.shape[0])
slices = (seq_slice,) + tuple(slice(0, s) for s in t.shape[1:])
output[i][slices] = t
if return_mask:
mask[i, seq_slice] = 1
if return_mask:
return output, mask
return output
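
A brief usage sketch of the new `return_mask` option (illustrative, not part of the diff; assumes this branch is installed and that `pad` is imported from `trl.trainer.utils`):

import torch
from trl.trainer.utils import pad

seqs = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
ids, mask = pad(seqs, padding_value=0, padding_side="left", return_mask=True)
# ids  -> tensor([[1, 2, 3],
#                 [0, 4, 5]])
# mask -> tensor([[1, 1, 1],
#                 [0, 1, 1]])

This mirrors how `compute_loss` re-pads the sequences sampled from the replay buffer.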