Mirror of https://github.com/huggingface/trl.git, synced 2025-10-20 18:43:52 +08:00
Compare commits: 6 commits (aa25c2697c...grpo-dapo-)
Author | SHA1 | Date
---|---|---
 | 3625d7ad02 |
 | 506a7904d1 |
 | 10d53d2302 |
 | 91a61b1e27 |
 | 1cceb921a5 |
 | d0cf13f2bf |
@@ -25,6 +25,7 @@ import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
from torch.utils.data import DataLoader
from packaging import version
from torch import nn
from torch.utils.data import Sampler
@@ -75,6 +76,74 @@ if is_wandb_available():
# rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
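For context, a reward function passed in the callable form of `RewardFunc` would look roughly like the sketch below; the function name and the length-based heuristic are purely illustrative and not part of this diff.

def toy_length_reward(prompts: list, completions: list) -> list[float]:
    # Hypothetical reward: favors shorter completions. Illustrative only; a real
    # reward function would score completions however the task requires.
    return [-float(len(completion)) for completion in completions]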
import numpy as np
import random


# Simple FIFO replay buffer; samples without replacement via a shuffled index queue.
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.sample_indices = []

    def add(self, experience):
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
        else:
            # Evict the oldest entry when the buffer is full
            self.buffer.pop(0)
            self.buffer.append(experience)

        # Clear index queue when buffer changes
        self.sample_indices.clear()

    def _init_sampling_queue(self):
        self.sample_indices = list(range(len(self.buffer)))
        random.shuffle(self.sample_indices)

    def sample(self, batch_size):
        if not self.sample_indices:
            self._init_sampling_queue()

        batch = []
        while len(batch) < batch_size and self.sample_indices:
            idx = self.sample_indices.pop(0)
            batch.append(self.buffer[idx])

        return batch

    def __len__(self):
        return len(self.buffer)
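A minimal usage sketch of the `ReplayBuffer` above (capacity and payloads are made-up values): once full, `add` evicts the oldest entry, and `sample` draws without repetition until the shuffled index queue is exhausted.

buf = ReplayBuffer(capacity=4)
for step in range(6):
    buf.add({"id": step})         # entries 0 and 1 are evicted once the buffer is full
assert len(buf) == 4              # buffer now holds entries 2, 3, 4, 5
batch = buf.sample(batch_size=2)  # no index repeats until the shuffled queue is refilled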
# Replay buffer that samples experiences with probability proportional to |advantage| ** alpha.
class PrioritizedReplayBuffer(ReplayBuffer):
    def __init__(self, capacity, alpha=1.0):
        super().__init__(capacity)
        self.alpha = alpha
        self.advantages = []

    def add(self, experience):
        EPS = 0.0001  # ensures non-zero priorities when the buffer contains all-zero advantages
        advantage = experience["advantages"].item()
        if len(self.buffer) < self.capacity:
            self.buffer.append(experience)
            self.advantages.append(abs(advantage) + EPS)  # Store absolute advantage
        else:
            # Replace the oldest entry if the buffer is full
            self.buffer.pop(0)
            self.advantages.pop(0)
            self.buffer.append(experience)
            self.advantages.append(abs(advantage) + EPS)

    def sample(self, batch_size):
        if not self.buffer:
            raise ValueError("Buffer is empty. Cannot sample from an empty buffer.")

        # Convert advantages to priorities
        scaled_priorities = np.power(self.advantages, self.alpha)
        total_priority = np.sum(scaled_priorities)
        probabilities = scaled_priorities / total_priority

        indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
        return [self.buffer[i] for i in indices]
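For intuition, a small worked example of the priority weighting in `sample` (the advantage magnitudes are made up): with `alpha=1.0`, sampling probability is directly proportional to the stored `|advantage|`, while `alpha < 1` flattens the distribution and `alpha = 0` makes it uniform.

import numpy as np

advs = np.array([0.1, 0.4, 0.0001])  # stored |advantage| (+ EPS) values, made up
priorities = np.power(advs, 1.0)     # alpha = 1.0
probs = priorities / priorities.sum()
# probs ≈ [0.200, 0.800, 0.0002]: high-|advantage| experiences dominate the draws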
class RepeatRandomSampler(Sampler):
    """

@@ -437,11 +506,9 @@ class GRPOTrainer(Trainer):
        self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
        self.epsilon_low = args.epsilon
        self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
        # Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
        self._step = 0
        # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
        # `_get_train_sampler` and `_prepare_inputs`.
        self._buffered_inputs = [None] * args.gradient_accumulation_steps
        self._buffered_inputs = []

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
@@ -585,7 +652,36 @@ class GRPOTrainer(Trainer):
        for i, reward_func in enumerate(self.reward_funcs):
            if isinstance(reward_func, PreTrainedModel):
                self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)

        mega_batch_size = (
            self.args.per_device_train_batch_size
            * self.accelerator.num_processes
            * self.args.gradient_accumulation_steps
            * self.num_generations
        )
        self._per_device_mega_batch_size = mega_batch_size // self.accelerator.num_processes
        gen_sampler = RepeatRandomSampler(
            data_source=self.train_dataset,
            mini_repeat_count=self.num_generations,
            batch_size=mega_batch_size,
            repeat_count=1,  # We only need to sample once for the generation
            seed=self.args.seed,
        )
        self._gen_dataloader = DataLoader(
            self.train_dataset,
            sampler=gen_sampler,
            batch_size=mega_batch_size,
            collate_fn=lambda x: x,
        )

        self._replay_buffer = PrioritizedReplayBuffer(
            capacity=mega_batch_size * 4,
            alpha=1.0,
        )
        self._num_samples_until_gen = 0
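To make the sizes concrete, a back-of-envelope example with assumed (not actual) config values:

# Assumed values, for illustration only.
per_device_train_batch_size = 4
num_processes = 8
gradient_accumulation_steps = 2
num_generations = 8

mega_batch_size = per_device_train_batch_size * num_processes * gradient_accumulation_steps * num_generations  # 512
per_device_mega_batch_size = mega_batch_size // num_processes  # 64 rows generated and scored per process per refill
replay_capacity = mega_batch_size * 4  # 2048 experiences retained in the prioritized buffer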
    def _set_signature_columns_if_needed(self):
        # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
        # By default, this method sets `self._signature_columns` to the model's expected inputs.

@@ -732,18 +828,51 @@ class GRPOTrainer(Trainer):
            self.vllm_client.reset_prefix_cache()

    @profiling_decorator
    def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
    def _prepare_inputs(self, _unused_inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
        def collate_fn(features):
            out = {}
            for key in features[0].keys():
                if isinstance(features[0][key], float) or isinstance(features[0][key][0], float):
                    out[key] = torch.tensor([f[key] for f in features], dtype=torch.float32)
                elif isinstance(features[0][key][0], int):
                    out[key] = torch.stack([torch.LongTensor(f[key]) for f in features], dim=0)
                else:
                    raise KeyError(f"Unsupported type {type(features[0][key])} for key {key}")
            return out

        mode = "eval" if self.control.should_evaluate else "train"
        if mode == "train":
            buffer_index = self._step % self.args.gradient_accumulation_steps
            buffered_inputs = self._buffered_inputs[buffer_index]
            if self.state.global_step % self.num_iterations == 0 or buffered_inputs is None:
                # buffered_inputs=None can occur when resuming from a checkpoint
                inputs = self._generate_and_score_completions(inputs)
                self._buffered_inputs[buffer_index] = inputs
            else:
                inputs = buffered_inputs
            self._step += 1
            if self._num_samples_until_gen == 0:
                def repeat_generator():
                    while True:
                        yield from self._gen_dataloader

                iter_dataloader = iter(repeat_generator())
                inputs = next(iter_dataloader)
                process_index = self.accelerator.process_index
                inputs = inputs[process_index * self._per_device_mega_batch_size : (process_index + 1) * self._per_device_mega_batch_size]

                generations = self._generate_and_score_completions(inputs)
                self._num_samples_until_gen = self.num_iterations * self._per_device_mega_batch_size

                for i in range(self._per_device_mega_batch_size):
                    sample = {}
                    for k, v in generations.items():
                        sample[k] = v[i]
                    self._replay_buffer.add(sample)

            samples = self._replay_buffer.sample(self.args.per_device_train_batch_size)

            inputs = {}
            for k in samples[0].keys():
                # padding may change between mega batches, TODO
                # TODO: padding should be rewritten, as the completion ids will not be padded correctly
                if k == "advantages":
                    inputs[k] = torch.stack([sample[k] for sample in samples])
                else:
                    inputs[k] = pad([sample[k] for sample in samples], padding_value=self.processing_class.pad_token_id)

            self._num_samples_until_gen -= self.args.per_device_train_batch_size
        else:
            # In evaluation, we don't reuse completions across multiple updates, so we don't need to buffer inputs.
            inputs = self._generate_and_score_completions(inputs)
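A toy, self-contained illustration of the refill-and-reuse cycle driven by `_num_samples_until_gen` (sizes are made up; scalar tensors stand in for real generations, and the `PrioritizedReplayBuffer` defined earlier in this diff is reused):

import torch

per_device_mega_batch_size = 8
per_device_train_batch_size = 2
num_iterations = 2

buffer = PrioritizedReplayBuffer(capacity=per_device_mega_batch_size * 4)
num_samples_until_gen = 0

for step in range(12):
    if num_samples_until_gen == 0:
        # Refill: in the trainer this is _generate_and_score_completions on a mega batch.
        for i in range(per_device_mega_batch_size):
            buffer.add({"advantages": torch.tensor(float(i) / 10)})
        num_samples_until_gen = num_iterations * per_device_mega_batch_size
    # Every step: draw a prioritized mini-batch from the buffer.
    batch = buffer.sample(per_device_train_batch_size)
    num_samples_until_gen -= per_device_train_batch_size
# Here each mega batch is reused for (2 * 8) / 2 = 8 optimizer steps before a refill.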
@@ -838,15 +967,9 @@ class GRPOTrainer(Trainer):
        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens

        with torch.no_grad():
            # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
            # computation here and use per_token_logps.detach() instead.
            if self.num_iterations > 1:
                old_per_token_logps = self._get_per_token_logps(
                    self.model, prompt_completion_ids, attention_mask, logits_to_keep
                )
            else:
                old_per_token_logps = None

            old_per_token_logps = self._get_per_token_logps(
                self.model, prompt_completion_ids, attention_mask, logits_to_keep
            )
            if self.beta == 0.0:
                ref_per_token_logps = None
            elif self.ref_model is not None:
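For reference, the `old_per_token_logps` computed above feed the clipped policy-gradient loss; below is a rough paraphrase (not the trainer's verbatim code) of the standard GRPO objective with the asymmetric `epsilon_low`/`epsilon_high` clipping configured earlier in this diff.

import torch

def clipped_per_token_loss(per_token_logps, old_per_token_logps, advantages, epsilon_low, epsilon_high):
    # Importance ratio between the current policy and the policy that generated the completions.
    coef_1 = torch.exp(per_token_logps - old_per_token_logps)
    # Asymmetric clipping ("clip-higher" when epsilon_high > epsilon_low, as in DAPO).
    coef_2 = torch.clamp(coef_1, 1 - epsilon_low, 1 + epsilon_high)
    loss_unclipped = coef_1 * advantages.unsqueeze(1)
    loss_clipped = coef_2 * advantages.unsqueeze(1)
    return -torch.min(loss_unclipped, loss_clipped)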