Compare commits

...

6 Commits

Author SHA1 Message Date
3625d7ad02 replay buffer 2025-04-17 19:01:04 +00:00
506a7904d1 remove none return 2025-04-15 12:02:51 +00:00
10d53d2302 clean comments 2025-04-15 11:51:40 +00:00
91a61b1e27 mega batches 2025-04-15 10:38:10 +00:00
1cceb921a5 add collate fn 2025-04-15 09:22:56 +00:00
d0cf13f2bf refactor buffer 2025-04-15 08:12:58 +00:00

View File

@ -25,6 +25,7 @@ import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
from torch.utils.data import DataLoader
from packaging import version
from torch import nn
from torch.utils.data import Sampler
@ -75,6 +76,74 @@ if is_wandb_available():
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
import numpy as np
import random
class ReplayBuffer:
def __init__(self, capacity):
self.capacity = capacity
self.buffer = []
self.sample_indices = []
def add(self, experience):
if len(self.buffer) < self.capacity:
self.buffer.append(experience)
else:
self.buffer.pop(0)
self.buffer.append(experience)
# Clear index queue when buffer changes
self.sample_indices.clear()
def _init_sampling_queue(self):
self.sample_indices = list(range(len(self.buffer)))
random.shuffle(self.sample_indices)
def sample(self, batch_size):
if not self.sample_indices:
self._init_sampling_queue()
batch = []
while len(batch) < batch_size and self.sample_indices:
idx = self.sample_indices.pop(0)
batch.append(self.buffer[idx])
return batch
def __len__(self):
return len(self.buffer)
class PrioritizedReplayBuffer(ReplayBuffer):
def __init__(self, capacity, alpha=1.0):
super().__init__(capacity)
self.alpha = alpha
self.advantages = []
def add(self, experience):
EPS = 0.0001 # ensures we get non-zero advs when the buffer contains all 0 advantages
advantage = experience["advantages"].item()
if len(self.buffer) < self.capacity:
self.buffer.append(experience)
self.advantages.append(abs(advantage) + EPS) # Store absolute advantage
else:
# Replace the oldest entry if the buffer is full
self.buffer.pop(0)
self.advantages.pop(0)
self.buffer.append(experience)
self.advantages.append(abs(advantage))
def sample(self, batch_size):
if not self.buffer:
raise ValueError("Buffer is empty. Cannot sample from an empty buffer.")
# Convert advantages to priorities
scaled_priorities = np.power(self.advantages, self.alpha)
total_priority = np.sum(scaled_priorities)
probabilities = scaled_priorities / total_priority
indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
return [self.buffer[i] for i in indices]
class RepeatRandomSampler(Sampler):
"""
@ -437,11 +506,9 @@ class GRPOTrainer(Trainer):
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
# Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
self._step = 0
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
# `_get_train_sampler` and `_prepare_inputs`.
self._buffered_inputs = [None] * args.gradient_accumulation_steps
self._buffered_inputs = []
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
@ -585,7 +652,36 @@ class GRPOTrainer(Trainer):
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
mega_batch_size = (
self.args.per_device_train_batch_size
* self.accelerator.num_processes
* self.args.gradient_accumulation_steps
* self.num_generations
)
self._per_device_mega_batch_size = mega_batch_size // self.accelerator.num_processes
gen_sampler = RepeatRandomSampler(
data_source=self.train_dataset,
mini_repeat_count=self.num_generations,
batch_size=mega_batch_size,
repeat_count=1, # We only need to sample once for the generation
seed=self.args.seed,
)
self._gen_dataloader = DataLoader(
self.train_dataset,
sampler=gen_sampler,
batch_size=mega_batch_size,
collate_fn=lambda x: x,
)
self._replay_buffer = PrioritizedReplayBuffer(
capacity=mega_batch_size*4,
alpha=1.0,
)
self._num_samples_until_gen = 0
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
@ -732,18 +828,51 @@ class GRPOTrainer(Trainer):
self.vllm_client.reset_prefix_cache()
@profiling_decorator
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
def _prepare_inputs(self, _unused_inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
def collate_fn(features):
out = {}
for key in features[0].keys():
if type(features[0][key]) == float or type(features[0][key][0]) == float:
out[key] = torch.tensor([f[key] for f in features], dtype=torch.float32)
elif type(features[0][key][0]) == int:
out[key] = torch.stack([torch.LongTensor(f[key]) for f in features], dim=0)
else:
raise KeyError(f"Unsupported type {type(features[0][key])} for key {key}")
return out
mode = "eval" if self.control.should_evaluate else "train"
if mode == "train":
buffer_index = self._step % self.args.gradient_accumulation_steps
buffered_inputs = self._buffered_inputs[buffer_index]
if self.state.global_step % self.num_iterations == 0 or buffered_inputs is None:
# buffered_inputs=None can occur when resuming from a checkpoint
inputs = self._generate_and_score_completions(inputs)
self._buffered_inputs[buffer_index] = inputs
else:
inputs = buffered_inputs
self._step += 1
if self._num_samples_until_gen == 0:
def repeat_generator():
while True:
yield from self._gen_dataloader
iter_dataloader = iter(repeat_generator())
inputs = next(iter_dataloader)
process_index = self.accelerator.process_index
inputs = inputs[process_index*self._per_device_mega_batch_size:(process_index+1)*self._per_device_mega_batch_size]
generations = self._generate_and_score_completions(inputs)
self._num_samples_until_gen = self.num_iterations * self._per_device_mega_batch_size
for i in range(self._per_device_mega_batch_size):
sample = {}
for k,v in generations.items():
sample[k] = v[i]
self._replay_buffer.add(sample)
samples = self._replay_buffer.sample(self.args.per_device_train_batch_size)
inputs = {}
for k in samples[0].keys():
# padding may change between mega batches, TODO
# TODO padding should be rewritten as the completion ids will not be padded correctino
if k == "advantages":
inputs[k] = torch.stack([sample[k] for sample in samples])
else:
inputs[k] = pad([sample[k] for sample in samples], padding_value=self.processing_class.pad_token_id)
self._num_samples_until_gen -= self.args.per_device_train_batch_size
else:
# In evaluation, we don't reuse completions across multiple updates, so we don't need to buffer inputs.
inputs = self._generate_and_score_completions(inputs)
@ -838,15 +967,9 @@ class GRPOTrainer(Trainer):
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
with torch.no_grad():
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
# computation here, and use per_token_logps.detach() instead.
if self.num_iterations > 1:
old_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)
else:
old_per_token_logps = None
old_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)
if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None: