mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
Compare commits
5 Commits
f6e7c200c0
...
grpo-mega-
Author | SHA1 | Date | |
---|---|---|---|
506a7904d1 | |||
10d53d2302 | |||
91a61b1e27 | |||
1cceb921a5 | |||
d0cf13f2bf |
@ -25,6 +25,7 @@ import torch.utils.data
|
||||
import transformers
|
||||
from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
|
||||
from datasets import Dataset, IterableDataset
|
||||
from torch.utils.data import DataLoader
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from torch.utils.data import Sampler
|
||||
@ -437,11 +438,9 @@ class GRPOTrainer(Trainer):
|
||||
self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
|
||||
self.epsilon_low = args.epsilon
|
||||
self.epsilon_high = args.epsilon_high if args.epsilon_high is not None else args.epsilon
|
||||
# Tracks the number of iterations (forward + backward passes), including those within a grad accum cycle
|
||||
self._step = 0
|
||||
# Buffer the batch to reuse generated outputs across multiple updates. For more details, see
|
||||
# `_get_train_sampler` and `_prepare_inputs`.
|
||||
self._buffered_inputs = [None] * args.gradient_accumulation_steps
|
||||
self._buffered_inputs = []
|
||||
|
||||
# The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
|
||||
# input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
|
||||
@ -585,7 +584,28 @@ class GRPOTrainer(Trainer):
|
||||
for i, reward_func in enumerate(self.reward_funcs):
|
||||
if isinstance(reward_func, PreTrainedModel):
|
||||
self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
|
||||
|
||||
|
||||
|
||||
mega_batch_size = (
|
||||
self.args.per_device_train_batch_size
|
||||
* self.accelerator.num_processes
|
||||
* self.args.gradient_accumulation_steps
|
||||
* self.num_generations
|
||||
)
|
||||
self._per_device_mega_batch_size = mega_batch_size // self.accelerator.num_processes
|
||||
gen_sampler = RepeatRandomSampler(
|
||||
data_source=self.train_dataset,
|
||||
mini_repeat_count=self.num_generations,
|
||||
batch_size=mega_batch_size,
|
||||
repeat_count=1, # We only need to sample once for the generation
|
||||
seed=self.args.seed,
|
||||
)
|
||||
self._gen_dataloader = DataLoader(
|
||||
self.train_dataset,
|
||||
sampler=gen_sampler,
|
||||
batch_size=mega_batch_size,
|
||||
collate_fn=lambda x: x,
|
||||
)
|
||||
def _set_signature_columns_if_needed(self):
|
||||
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
|
||||
# By default, this method sets `self._signature_columns` to the model's expected inputs.
|
||||
@ -732,22 +752,49 @@ class GRPOTrainer(Trainer):
|
||||
self.vllm_client.reset_prefix_cache()
|
||||
|
||||
@profiling_decorator
|
||||
def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
|
||||
def _prepare_inputs(self, _unused_inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
|
||||
def collate_fn(features):
|
||||
out = {}
|
||||
for key in features[0].keys():
|
||||
if type(features[0][key]) == float or type(features[0][key][0]) == float:
|
||||
out[key] = torch.tensor([f[key] for f in features], dtype=torch.float32)
|
||||
elif type(features[0][key][0]) == int:
|
||||
out[key] = torch.stack([torch.LongTensor(f[key]) for f in features], dim=0)
|
||||
else:
|
||||
raise KeyError(f"Unsupported type {type(features[0][key])} for key {key}")
|
||||
return out
|
||||
|
||||
mode = "eval" if self.control.should_evaluate else "train"
|
||||
if mode == "train":
|
||||
buffer_index = self._step % self.args.gradient_accumulation_steps
|
||||
buffered_inputs = self._buffered_inputs[buffer_index]
|
||||
if self.state.global_step % self.num_iterations == 0 or buffered_inputs is None:
|
||||
# buffered_inputs=None can occur when resuming from a checkpoint
|
||||
inputs = self._generate_and_score_completions(inputs)
|
||||
self._buffered_inputs[buffer_index] = inputs
|
||||
else:
|
||||
inputs = buffered_inputs
|
||||
self._step += 1
|
||||
if len(self._buffered_inputs) == 0:
|
||||
def repeat_generator():
|
||||
while True:
|
||||
yield from self._gen_dataloader
|
||||
|
||||
iter_dataloader = iter(repeat_generator())
|
||||
inputs = next(iter_dataloader)
|
||||
process_index = self.accelerator.process_index
|
||||
inputs = inputs[process_index*self._per_device_mega_batch_size:(process_index+1)*self._per_device_mega_batch_size]
|
||||
|
||||
generations = self._generate_and_score_completions(inputs)
|
||||
gen_dataset = Dataset.from_dict(generations)
|
||||
mini_batch_dataloader = DataLoader(
|
||||
gen_dataset,
|
||||
batch_size=self.args.per_device_train_batch_size,
|
||||
shuffle=True, # we technically don't need to shuffle due to grad acc, but we will decouple later
|
||||
drop_last=True,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
for num_iters in range(self.args.num_iterations):
|
||||
for mini_batch in mini_batch_dataloader:
|
||||
self._buffered_inputs.append(mini_batch)
|
||||
inputs = self._buffered_inputs.pop(0)
|
||||
else:
|
||||
# In evaluation, we don't reuse completions across multiple updates, so we don't need to buffer inputs.
|
||||
inputs = self._generate_and_score_completions(inputs)
|
||||
return inputs
|
||||
return {
|
||||
k: v.to(self.accelerator.device) for k, v in inputs.items()
|
||||
}
|
||||
|
||||
def _generate_and_score_completions(
|
||||
self, inputs: dict[str, Union[torch.Tensor, Any]]
|
||||
@ -838,15 +885,9 @@ class GRPOTrainer(Trainer):
|
||||
logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
|
||||
|
||||
with torch.no_grad():
|
||||
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's
|
||||
# computation here, and use per_token_logps.detach() instead.
|
||||
if self.num_iterations > 1:
|
||||
old_per_token_logps = self._get_per_token_logps(
|
||||
self.model, prompt_completion_ids, attention_mask, logits_to_keep
|
||||
)
|
||||
else:
|
||||
old_per_token_logps = None
|
||||
|
||||
old_per_token_logps = self._get_per_token_logps(
|
||||
self.model, prompt_completion_ids, attention_mask, logits_to_keep
|
||||
)
|
||||
if self.beta == 0.0:
|
||||
ref_per_token_logps = None
|
||||
elif self.ref_model is not None:
|
||||
|
Reference in New Issue
Block a user