From e0eec055b412c48ad754149c475a87a8fca34fb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Fri, 17 Oct 2025 15:36:13 -0600 Subject: [PATCH] =?UTF-8?q?=F0=9F=A7=BA=20[4/N]=20Refactor=20`=5Fgenerate`?= =?UTF-8?q?=20in=20GRPO/RLOO:=20Move=20`forward=5Fkwargs`=20outside=20gene?= =?UTF-8?q?ration=20method=20(#4154)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Co-authored-by: YonatanGideoni Co-authored-by: burtenshaw Co-authored-by: sergiopaniego Co-authored-by: lewtun Co-authored-by: Kashif Rasul --- trl/experimental/gfpo/gfpo_trainer.py | 34 ++++++++++++-------- trl/trainer/grpo_trainer.py | 45 ++++++++++++++------------- trl/trainer/rloo_trainer.py | 37 ++++++++++++---------- 3 files changed, 65 insertions(+), 51 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index 892183792..f119c35c6 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -18,7 +18,7 @@ from typing import Any, Callable import torch from accelerate.utils import gather_object -from ...data_utils import is_conversational +from ...data_utils import apply_chat_template, is_conversational from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer from ...trainer.utils import nanmax, nanmin, nanstd, pad @@ -80,13 +80,9 @@ class GFPOTrainer(_GRPOTrainer): if images is not None and all(img_list == [] for img_list in images): images = None - ( - prompt_ids_list, - completion_ids_list, - num_items_in_batch, - sampling_per_token_logps_list, - forward_kwargs, - ) = self._generate(prompts, images) + prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate( + prompts, images + ) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -112,6 +108,23 @@ class GFPOTrainer(_GRPOTrainer): # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + # If token_type_ids are used, extend them with zeros for the completion part if "token_type_ids" in forward_kwargs: token_type_ids = forward_kwargs["token_type_ids"] @@ -119,11 +132,6 @@ class GFPOTrainer(_GRPOTrainer): [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 ) - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - - num_images = [len(img_list) for img_list in images] if images is not None else None - with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 352a0144e..eaeb6eb5a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1086,13 +1086,6 @@ class GRPOTrainer(BaseTrainer): maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts ] - if images is not None: - prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - else: - forward_kwargs = {} - # Generate completions using either vLLM or regular generation if self.use_vllm: if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: @@ -1307,13 +1300,13 @@ class GRPOTrainer(BaseTrainer): completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] logprobs = None # not used in this case - return prompt_ids, completion_ids, logprobs, forward_kwargs + return prompt_ids, completion_ids, logprobs def _generate(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images) + prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts, images) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1345,7 +1338,7 @@ class GRPOTrainer(BaseTrainer): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs + return prompt_ids, completion_ids, total_completion_tokens, logprobs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1365,13 +1358,9 @@ class GRPOTrainer(BaseTrainer): if images is not None and all(img_list == [] for img_list in images): images = None - ( - prompt_ids_list, - completion_ids_list, - num_items_in_batch, - sampling_per_token_logps_list, - forward_kwargs, - ) = self._generate(prompts, images) + prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate( + prompts, images + ) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -1397,6 +1386,23 @@ class GRPOTrainer(BaseTrainer): # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + # If token_type_ids are used, extend them with zeros for the completion part if "token_type_ids" in forward_kwargs: token_type_ids = forward_kwargs["token_type_ids"] @@ -1404,11 +1410,6 @@ class GRPOTrainer(BaseTrainer): [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 ) - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - - num_images = [len(img_list) for img_list in images] if images is not None else None - with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 2ab9cef69..59fc9c285 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1082,13 +1082,6 @@ class RLOOTrainer(BaseTrainer): maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts ] - if images is not None: - prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - else: - forward_kwargs = {} - # Generate completions using either vLLM or regular generation if self.use_vllm: if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: @@ -1292,13 +1285,13 @@ class RLOOTrainer(BaseTrainer): prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] - return prompt_ids, completion_ids, forward_kwargs + return prompt_ids, completion_ids def _generate(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images) + prompt_ids, completion_ids = self._generate_single_turn(prompts, images) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1331,7 +1324,7 @@ class RLOOTrainer(BaseTrainer): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, forward_kwargs + return prompt_ids, completion_ids def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1351,7 +1344,7 @@ class RLOOTrainer(BaseTrainer): if images is not None and all(img_list == [] for img_list in images): images = None - prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images) + prompt_ids_list, completion_ids_list = self._generate(prompts, images) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -1372,6 +1365,23 @@ class RLOOTrainer(BaseTrainer): # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + # If token_type_ids are used, extend them with zeros for the completion part if "token_type_ids" in forward_kwargs: token_type_ids = forward_kwargs["token_type_ids"] @@ -1379,11 +1389,6 @@ class RLOOTrainer(BaseTrainer): [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 ) - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - - num_images = [len(img_list) for img_list in images] if images is not None else None - with torch.no_grad(): # Compute the per-token log probabilities for the current model old_per_token_logps, _ = self._get_per_token_logps_and_entropies(