Compare commits

...

5 Commits

Author SHA1 Message Date
506a7904d1 remove none return 2025-04-15 12:02:51 +00:00
10d53d2302 clean comments 2025-04-15 11:51:40 +00:00
91a61b1e27 mega batches 2025-04-15 10:38:10 +00:00
1cceb921a5 add collate fn 2025-04-15 09:22:56 +00:00
d0cf13f2bf refactor buffer 2025-04-15 08:12:58 +00:00


@@ -25,6 +25,7 @@ import torch.utils.data
import transformers
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
from datasets import Dataset, IterableDataset
from torch.utils.data import DataLoader
from packaging import version
from torch import nn
from torch.utils.data import Sampler
@@ -437,11 +438,9 @@ class GRPOTrainer(Trainer):
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
self.epsilon_low = args.epsilon
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
# Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
self._step = 0
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
# `_get_train_sampler` and `_prepare_inputs`.
self._buffered_inputs = [None] * args.gradient_accumulation_steps
self._buffered_inputs = []
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
@@ -585,7 +584,28 @@ class GRPOTrainer(Trainer):
for i, reward_func in enumerate(self.reward_funcs):
if isinstance(reward_func, PreTrainedModel):
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
mega_batch_size = (
self.args.per_device_train_batch_size
* self.accelerator.num_processes
* self.args.gradient_accumulation_steps
* self.num_generations
)
self._per_device_mega_batch_size = mega_batch_size // self.accelerator.num_processes
gen_sampler = RepeatRandomSampler(
data_source=self.train_dataset,
mini_repeat_count=self.num_generations,
batch_size=mega_batch_size,
repeat_count=1, # We only need to sample once for the generation
seed=self.args.seed,
)
self._gen_dataloader = DataLoader(
self.train_dataset,
sampler=gen_sampler,
batch_size=mega_batch_size,
collate_fn=lambda x: x,
)
def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
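For orientation (not part of the diff), the mega-batch sizing above worked through with illustrative numbers; every value below is hypothetical and only the arithmetic mirrors the code:

per_device_train_batch_size = 4
num_processes = 2
gradient_accumulation_steps = 8
num_generations = 8  # G in the GRPO paper

mega_batch_size = (
    per_device_train_batch_size
    * num_processes
    * gradient_accumulation_steps
    * num_generations
)  # 4 * 2 * 8 * 8 = 512 rows per generation cycle (each prompt repeated num_generations times by the sampler)

per_device_mega_batch_size = mega_batch_size // num_processes  # 512 // 2 = 256 rows sliced out per process in _prepare_inputs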
@@ -732,22 +752,49 @@ class GRPOTrainer(Trainer):
self.vllm_client.reset_prefix_cache()
@profiling_decorator
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
def _prepare_inputs(self, _unused_inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
def collate_fn(features):
out = {}
for key in features[0].keys():
if isinstance(features[0][key], float) or isinstance(features[0][key][0], float):
out[key] = torch.tensor([f[key] for f in features], dtype=torch.float32)
elif isinstance(features[0][key][0], int):
out[key] = torch.stack([torch.LongTensor(f[key]) for f in features], dim=0)
else:
raise TypeError(f"Unsupported type {type(features[0][key])} for key {key}")
return out
mode = "eval" if self.control.should_evaluate else "train"
if mode == "train":
buffer_index = self._step % self.args.gradient_accumulation_steps
buffered_inputs = self._buffered_inputs[buffer_index]
if self.state.global_step % self.num_iterations == 0 or buffered_inputs is None:
# buffered_inputs=None can occur when resuming from a checkpoint
inputs = self._generate_and_score_completions(inputs)
self._buffered_inputs[buffer_index] = inputs
else:
inputs = buffered_inputs
self._step += 1
if len(self._buffered_inputs) == 0:
def repeat_generator():
while True:
yield from self._gen_dataloader
iter_dataloader = iter(repeat_generator())
inputs = next(iter_dataloader)
process_index = self.accelerator.process_index
inputs = inputs[process_index * self._per_device_mega_batch_size : (process_index + 1) * self._per_device_mega_batch_size]
generations = self._generate_and_score_completions(inputs)
gen_dataset = Dataset.from_dict(generations)
mini_batch_dataloader = DataLoader(
gen_dataset,
batch_size=self.args.per_device_train_batch_size,
shuffle=True,  # not strictly needed while everything stays in one grad-accum cycle, but kept for when this is decoupled later
drop_last=True,
collate_fn=collate_fn,
)
for _ in range(self.args.num_iterations):
for mini_batch in mini_batch_dataloader:
self._buffered_inputs.append(mini_batch)
inputs = self._buffered_inputs.pop(0)
else:
# In evaluation, we don't reuse completions across multiple updates, so we don't need to buffer inputs.
inputs = self._generate_and_score_completions(inputs)
return inputs
return {
k: v.to(self.accelerator.device) for k, v in inputs.items()
}
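Again for orientation (not part of the diff): one refill of `self._buffered_inputs` holds every mini-batch the per-process mega-batch can supply, repeated `num_iterations` times, and each subsequent `_prepare_inputs` call pops one until the buffer drains. With the same illustrative numbers as above:

per_device_mega_batch_size = 256   # hypothetical, from the sizing sketch above
per_device_train_batch_size = 4
num_iterations = 2                 # mu in the GRPO paper

mini_batches_per_refill = (per_device_mega_batch_size // per_device_train_batch_size) * num_iterations
# (256 // 4) * 2 = 128 shuffled mini-batches buffered per generation/scoring pass
# (drop_last=True discards any remainder smaller than per_device_train_batch_size)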
def _generate_and_score_completions(
self, inputs: dict[str, Union[torch.Tensor, Any]]
@@ -838,15 +885,9 @@ class GRPOTrainer(Trainer):
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
with torch.no_grad():
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
# computation here and use per_token_logps.detach() instead.
if self.num_iterations > 1:
old_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)
else:
old_per_token_logps = None
old_per_token_logps = self._get_per_token_logps(
self.model, prompt_completion_ids, attention_mask, logits_to_keep
)
if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None:
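Since completions are now buffered and reused across `num_iterations` passes over several mini-batches, the old (behavior-policy) log-probs are always materialized above instead of being skipped when num_iterations == 1. As a generic sketch of where those detached log-probs feed in, not the trainer's actual loss code (the helper name and defaults are made up for illustration, and `advantages` is assumed already broadcast to per-token shape):

import torch

def clipped_surrogate(per_token_logps, old_per_token_logps, advantages, epsilon_low=0.2, epsilon_high=0.2):
    # PPO/GRPO-style clipped objective per token; epsilon_low / epsilon_high mirror the
    # fields set in __init__ above, and old_per_token_logps is assumed already detached.
    ratio = torch.exp(per_token_logps - old_per_token_logps)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - epsilon_low, 1.0 + epsilon_high) * advantages
    return torch.min(unclipped, clipped)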