Compare commits

...

9 Commits

5 changed files with 395 additions and 5 deletions

tests/test_repad.py Normal file

@ -0,0 +1,131 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from copy import deepcopy

import torch

from trl.trainer.grpo_replay_buffer import repad

PAD_TOKEN_ID = 123


def test_repad_basic_padding():
sample = [
{
"prompt_ids": torch.LongTensor([1, 2, 3]),
"prompt_mask": torch.LongTensor([1, 1, 0]),
"completion_ids": torch.LongTensor([5, 6, 7, 8]),
"completion_mask": torch.LongTensor([1, 1, 1, 0]),
"old_per_token_logps": torch.tensor([0.1, 0.2, 0.3, 0.4]),
"ref_per_token_logps": torch.tensor([0.0, -0.1, -0.2, -0.3]),
},
{
"prompt_ids": torch.LongTensor([4, 5]),
"prompt_mask": torch.LongTensor([1, 1]),
"completion_ids": torch.LongTensor([9, 10]),
"completion_mask": torch.LongTensor([1, 1]),
"old_per_token_logps": torch.tensor([-0.5, -0.6]),
"ref_per_token_logps": torch.tensor([0.5, 0.6]),
},
]
padded = repad(deepcopy(sample), padding_value=PAD_TOKEN_ID)
assert len(padded[0]["prompt_ids"]) == 2
assert len(padded[0]["completion_ids"]) == 3
for ex in padded:
# All sequences in same batch should have same length
assert len(ex["prompt_ids"]) == len(padded[0]["prompt_ids"])
assert len(ex["prompt_mask"]) == len(padded[0]["prompt_mask"])
assert len(ex["completion_ids"]) == len(padded[0]["completion_ids"])
assert len(ex["completion_mask"]) == len(padded[0]["completion_mask"])
# Mask and ids should match in shape
assert ex["prompt_ids"].shape == ex["prompt_mask"].shape
assert ex["completion_ids"].shape == ex["completion_mask"].shape


def test_repad_logps_padding():
sample = [
{
"prompt_ids": torch.LongTensor([1]),
"prompt_mask": torch.LongTensor([1]),
"completion_ids": torch.LongTensor([2, 3, 4]),
"completion_mask": torch.LongTensor([1, 1, 0]),
"old_per_token_logps": torch.tensor([-0.1, -0.2, -0.3]),
"ref_per_token_logps": torch.tensor([-0.5, -0.6, -0.7]),
},
{
"prompt_ids": torch.LongTensor([5, 6]),
"prompt_mask": torch.LongTensor([1, 1]),
"completion_ids": torch.LongTensor([7, 8]),
"completion_mask": torch.LongTensor([1, 1]),
"old_per_token_logps": torch.tensor([0.4, 0.5]),
"ref_per_token_logps": torch.tensor([0.6, 0.7]),
},
]
padded = repad(deepcopy(sample), padding_value=PAD_TOKEN_ID)
for logps in ["old_per_token_logps", "ref_per_token_logps"]:
for ex in padded:
assert len(ex[logps]) == len(padded[0][logps])
assert isinstance(ex[logps], torch.Tensor)


def test_repad_empty_masks():
sample = [
{
"prompt_ids": torch.tensor([0]),
"prompt_mask": torch.tensor([0]),
"completion_ids": torch.tensor([0]),
"completion_mask": torch.tensor([0]),
"old_per_token_logps": torch.tensor([0.0]),
"ref_per_token_logps": torch.tensor([0.0]),
},
{
"prompt_ids": torch.tensor([1]),
"prompt_mask": torch.tensor([0]),
"completion_ids": torch.tensor([1]),
"completion_mask": torch.tensor([0]),
"old_per_token_logps": torch.tensor([0.0]),
"ref_per_token_logps": torch.tensor([0.0]),
},
{
"prompt_ids": torch.tensor([1, 1]),
"prompt_mask": torch.tensor([0, 1]),
"completion_ids": torch.tensor([1, 2]),
"completion_mask": torch.tensor([1, 0]),
"old_per_token_logps": torch.tensor([0.0, 1.0]),
"ref_per_token_logps": torch.tensor([0.0, 1.0]),
},
{
"prompt_ids": torch.tensor([1, 1]),
"prompt_mask": torch.tensor([1, 1]),
"completion_ids": torch.tensor([1, 2]),
"completion_mask": torch.tensor([1, 0]),
"old_per_token_logps": torch.tensor([0.0, 1.0]),
"ref_per_token_logps": torch.tensor([0.0, 1.0]),
},
]
padded = repad(deepcopy(sample), padding_value=999)
assert len(padded[0]["prompt_ids"]) == 2
assert len(padded[0]["completion_ids"]) == 1
assert padded[0]["prompt_ids"].eq(999).all()
assert padded[0]["completion_ids"].eq(999).all()

train_grpo.py Normal file

@ -0,0 +1,35 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import load_dataset

from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")


# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(output_dir="Qwen2-0.5B-GRPO", logging_steps=1, replay_buffer_class="SSRReplayBuffer")
trainer = GRPOTrainer(
model="Qwen/Qwen2-0.5B-Instruct",
reward_funcs=reward_len,
args=training_args,
train_dataset=dataset,
)
trainer.train()

trl/trainer/grpo_config.py

@ -153,6 +153,8 @@ class GRPOConfig(TrainingArguments):
use_liger_loss (`bool`, *optional*, defaults to `False`):
Whether to use the Liger GRPO loss.
replay_buffer_class (`str`, *optional*, defaults to `"ReplayBuffer"`):
Replay buffer class to use. Options are `"ReplayBuffer"` (random sampling without replacement) and `"SSRReplayBuffer"` (advantage-prioritized sampling).
> Parameters that control the logging
log_completions (`bool`, *optional*, defaults to `False`):
@ -393,6 +395,26 @@ class GRPOConfig(TrainingArguments):
metadata={"help": "Whether to use the Liger GRPO loss."},
)
replay_buffer_class: str = field(
default="ReplayBuffer",
metadata={
"help": "Replay buffer class to use, Options [ReplayBuffer, SSRReplayBuffer] The default is `ReplayBuffer`, that randomly samples without replacement."
},
)
ssr_capacity_scalar: int = field(
default=4,
metadata={
"help": "Scalar to multiply the replay buffer capacity. The default is 1, which means the capacity is "
"equal to the number of training samples in the effective batch."
},
)
ssr_alpha: float = field(
default=1.0,
metadata={
"help": "Alpha parameter for controlling the probablity distribution of the replay buffer. The default is 1.0, "
},
)
# Parameters that control the logging
log_completions: bool = field(
default=False,

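Taken together with train_grpo.py above, the new options are set directly on `GRPOConfig`. A minimal sketch follows; the values shown are illustrative, not tuned recommendations.

from trl import GRPOConfig

# Illustrative values only. Defaults in this PR: replay_buffer_class="ReplayBuffer",
# ssr_capacity_scalar=4, ssr_alpha=1.0.
training_args = GRPOConfig(
    output_dir="Qwen2-0.5B-GRPO",
    logging_steps=1,
    replay_buffer_class="SSRReplayBuffer",  # advantage-prioritized resampling instead of uniform sampling
    ssr_capacity_scalar=4,                  # buffer capacity = 4 x the effective batch size
    ssr_alpha=1.0,                          # sampling priorities are |advantage| ** alpha
)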
trl/trainer/grpo_replay_buffer.py Normal file

@ -0,0 +1,154 @@
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random

import numpy as np

from .utils import pad


def repad(list_of_tensor_dicts, padding_value):
p_ids, p_attn_masks = remove_and_pad(
[tensor_dict["prompt_ids"] for tensor_dict in list_of_tensor_dicts],
[tensor_dict["prompt_mask"] for tensor_dict in list_of_tensor_dicts],
pad_token_id=padding_value,
padding_side="left",
)
c_ids, c_attn_masks = remove_and_pad(
[tensor_dict["completion_ids"] for tensor_dict in list_of_tensor_dicts],
[tensor_dict["completion_mask"] for tensor_dict in list_of_tensor_dicts],
pad_token_id=padding_value,
)
old_logps, _ = remove_and_pad(
[tensor_dict["old_per_token_logps"] for tensor_dict in list_of_tensor_dicts],
[tensor_dict["completion_mask"] for tensor_dict in list_of_tensor_dicts],
pad_token_id=-10000.0, # ignored so can be anything
)
ref_logps, _ = remove_and_pad(
[tensor_dict["ref_per_token_logps"] for tensor_dict in list_of_tensor_dicts],
[tensor_dict["completion_mask"] for tensor_dict in list_of_tensor_dicts],
pad_token_id=-10000.0, # ignored so can be anything
)
for i, (p_id, p_mask, c_id, c_mask, o_logp, r_logp) in enumerate(
zip(p_ids, p_attn_masks, c_ids, c_attn_masks, old_logps, ref_logps)
):
list_of_tensor_dicts[i]["prompt_ids"] = p_id
list_of_tensor_dicts[i]["prompt_mask"] = p_mask
list_of_tensor_dicts[i]["completion_ids"] = c_id
list_of_tensor_dicts[i]["completion_mask"] = c_mask
list_of_tensor_dicts[i]["old_per_token_logps"] = o_logp
list_of_tensor_dicts[i]["ref_per_token_logps"] = r_logp
return list_of_tensor_dicts


def remove_and_pad(list_of_ids, list_of_masks, pad_token_id=0, padding_side="right"):
"""
Remove padding from list_of_ids and list_of_masks, and then pad them to the same length.
"""
num_samples = len(list_of_ids)
if list_of_ids[0] is None:
# we are not using old_per_token_logps / ref_per_token_logps
return [None] * num_samples, [None] * num_samples
# Remove padding
list_of_ids = [ids[mask == 1] for ids, mask in zip(list_of_ids, list_of_masks)]
list_of_masks = [mask[mask == 1] for mask in list_of_masks]
ids = pad(list_of_ids, padding_value=pad_token_id, padding_side=padding_side)
masks = pad(list_of_masks, padding_value=0, padding_side=padding_side)
return ids, masks


def remove_padding(input_ids, attn_mask):
"""
Remove padding from input_ids and attn_mask.
"""
if attn_mask is not None:
input_ids = input_ids[attn_mask == 1]
attn_mask = attn_mask[attn_mask == 1]
return input_ids, attn_mask


class ReplayBuffer:
    """Fixed-capacity FIFO buffer; `sample` draws without replacement by consuming a shuffled index queue."""
def __init__(self, capacity):
self.capacity = capacity
self.buffer = []
self.sample_indices = []
def add(self, experience):
if len(self.buffer) < self.capacity:
self.buffer.append(experience)
else:
self.buffer.pop(0)
self.buffer.append(experience)
# Clear index queue when buffer changes
self.sample_indices.clear()
def _init_sampling_queue(self):
self.sample_indices = list(range(len(self.buffer)))
random.shuffle(self.sample_indices)
def sample(self, batch_size):
if not self.sample_indices:
self._init_sampling_queue()
batch = []
while len(batch) < batch_size and self.sample_indices:
idx = self.sample_indices.pop(0)
batch.append(self.buffer[idx])
if len(batch) != batch_size:
raise ValueError("Not enough samples in the buffer to fill the batch.")
return batch
def __len__(self):
return len(self.buffer)


class SSRReplayBuffer(ReplayBuffer):
    # Implementation of the SSR replay buffer from https://arxiv.org/pdf/2504.08837: experiences are sampled with
    # probability proportional to |advantage| ** alpha.
def __init__(self, capacity, alpha=1.0):
super().__init__(capacity)
self.alpha = alpha
self.advantages = []
def add(self, experience):
EPS = 0.0001 # ensures we get non-zero advs when the buffer contains all 0 advantages
advantage = experience["advantages"].item()
if len(self.buffer) < self.capacity:
self.buffer.append(experience)
self.advantages.append(abs(advantage) + EPS) # Store absolute advantage
else:
# Replace the oldest entry if the buffer is full
self.buffer.pop(0)
self.advantages.pop(0)
self.buffer.append(experience)
self.advantages.append(abs(advantage) + EPS)  # keep the EPS offset so priorities can never all be zero
def sample(self, batch_size):
if not self.buffer:
raise ValueError("Buffer is empty. Cannot sample from an empty buffer.")
# Convert advantages to priorities
scaled_priorities = np.power(self.advantages, self.alpha)
total_priority = np.sum(scaled_priorities)
probabilities = scaled_priorities / total_priority
indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
return [self.buffer[i] for i in indices]
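As a quick illustration of the prioritized sampling implemented by `SSRReplayBuffer`, the following standalone sketch (toy `advantages` values, assuming this branch is installed) fills a small buffer and samples from it repeatedly; entries with larger absolute advantage are drawn far more often, and sampling is with replacement.

from collections import Counter

import torch

from trl.trainer.grpo_replay_buffer import SSRReplayBuffer

buffer = SSRReplayBuffer(capacity=4, alpha=1.0)
# Only the "advantages" entry is used to compute the sampling priority.
for name, adv in [("a", 0.0), ("b", 0.1), ("c", 0.5), ("d", 2.0)]:
    buffer.add({"name": name, "advantages": torch.tensor(adv)})

counts = Counter(e["name"] for _ in range(1000) for e in buffer.sample(1))
print(counts)  # "d" dominates and "a" is almost never drawn (priorities roughly 2.0 : 0.5 : 0.1 : 0)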

trl/trainer/grpo_trainer.py

@ -51,6 +51,7 @@ from ..import_utils import is_liger_kernel_available, is_rich_available, is_vllm
from ..models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
from .callbacks import SyncRefModelCallback
from .grpo_config import GRPOConfig
from .grpo_replay_buffer import ReplayBuffer, SSRReplayBuffer, repad
from .utils import (
disable_dropout_in_model,
generate_model_card,
@ -219,6 +220,31 @@ def split_tensor_dict(
]
def combine_tensor_dict(split_dicts: list[dict[str, Optional[torch.Tensor]]]) -> dict[str, Optional[torch.Tensor]]:
"""
Combines a list of dictionaries of per-example tensors into a single dictionary by stacking the tensors along a
new leading (batch) dimension. Keys whose tensors are `None` in every dictionary map to `None`.
Example:
>>> d1 = {"x": torch.tensor([0, 1]), "y": torch.tensor([0])}
>>> d2 = {"x": torch.tensor([2, 3]), "y": torch.tensor([1])}
>>> d3 = {"x": torch.tensor([4, 5]), "y": torch.tensor([2])}
>>> combine_tensor_dict([d1, d2, d3])
{
"x": tensor([[0, 1], [2, 3], [4, 5]]),
"y": tensor([[0], [1], [2]])
}
"""
combined_dict = {}
keys = split_dicts[0].keys()
for key in keys:
tensors = [d[key] for d in split_dicts if d[key] is not None]
combined_dict[key] = torch.stack(tensors, dim=0) if tensors else None
return combined_dict
def nanmin(tensor: torch.Tensor) -> torch.Tensor:
"""
Compute the minimum value of a tensor, ignoring NaNs. This function only supports 1D tensors.
@ -673,6 +699,22 @@ class GRPOTrainer(Trainer):
else:
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
# Set up the replay buffer used to store and re-sample per-example training inputs
effective_batch_size = self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps
if self.args.replay_buffer_class == "ReplayBuffer":
self.replay_buffer = ReplayBuffer(capacity=effective_batch_size)
elif self.args.replay_buffer_class == "SSRReplayBuffer":
self.replay_buffer = SSRReplayBuffer(
capacity=effective_batch_size * self.args.ssr_capacity_scalar,
alpha=self.args.ssr_alpha,
)
else:
raise ValueError(
f"Invalid `replay_buffer_class` passed to `GRPOConfig`. Expected either 'ReplayBuffer' or 'SSRReplayBuffer', but got {self.args.replay_buffer_class}."
)
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
@ -887,13 +929,19 @@ class GRPOTrainer(Trainer):
mode = "train" if self.model.training else "eval"
if mode == "train":
generate_every = self.args.gradient_accumulation_steps * self.num_iterations
-if self._step % generate_every == 0 or self._buffered_inputs is None:
+if self._step % generate_every == 0 or len(self.replay_buffer) == 0:
# self._buffered_inputs=None can occur when resuming from a checkpoint
accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
-self._buffered_inputs = split_tensor_dict(
-accumulated_local_batch, self.args.gradient_accumulation_steps
-)
-inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
+effective_batch_size = self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps
+split_tensors = split_tensor_dict(accumulated_local_batch, effective_batch_size)
+for tensor in split_tensors:
+self.replay_buffer.add(tensor)
+split_inputs = self.replay_buffer.sample(self.args.per_device_train_batch_size)
+repadded_split_inputs = repad(split_inputs, padding_value=self.processing_class.pad_token_id)
+inputs = combine_tensor_dict(repadded_split_inputs)
self._step += 1
else:
# In evaluation, there is neither gradient accumulation, nor multiple iterations
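End-to-end, the training path above becomes: generate and score an accumulated batch, split it into per-example dicts, add them to the replay buffer, sample a mini-batch, `repad` it to a common length, and recombine it into batch tensors with `combine_tensor_dict`. A minimal sketch of the `repad` + `combine_tensor_dict` round trip with two toy examples follows (arbitrary values; `combine_tensor_dict` is assumed importable from the trainer module on this branch).

import torch

from trl.trainer.grpo_replay_buffer import repad
from trl.trainer.grpo_trainer import combine_tensor_dict  # module-level helper added in this PR

examples = [
    {
        "prompt_ids": torch.LongTensor([1, 2, 3]),
        "prompt_mask": torch.LongTensor([1, 1, 0]),
        "completion_ids": torch.LongTensor([5, 6]),
        "completion_mask": torch.LongTensor([1, 1]),
        "old_per_token_logps": torch.tensor([-0.1, -0.2]),
        "ref_per_token_logps": torch.tensor([-0.3, -0.4]),
    },
    {
        "prompt_ids": torch.LongTensor([4]),
        "prompt_mask": torch.LongTensor([1]),
        "completion_ids": torch.LongTensor([7, 8, 9]),
        "completion_mask": torch.LongTensor([1, 1, 1]),
        "old_per_token_logps": torch.tensor([-0.5, -0.6, -0.7]),
        "ref_per_token_logps": torch.tensor([-0.8, -0.9, -1.0]),
    },
]

batch = combine_tensor_dict(repad(examples, padding_value=0))
print(batch["prompt_ids"].shape)      # torch.Size([2, 2]) -- prompts left-padded to the longest unmasked prompt
print(batch["completion_ids"].shape)  # torch.Size([2, 3]) -- completions right-padded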