Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)
🧺 [4/N] Refactor _generate in GRPO/RLOO: Move forward_kwargs outside generation method (#4154)
Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Co-authored-by: YonatanGideoni <yonatan.gideoni@gmail.com>
Co-authored-by: burtenshaw <ben.burtenshaw@gmail.com>
Co-authored-by: sergiopaniego <sergiopaniegoblanco@gmail.com>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
committed by GitHub
parent f4c554da22
commit e0eec055b4
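In outline: forward_kwargs holds the extra model inputs a multimodal processor produces (e.g. pixel_values, token_type_ids). Before this commit they were built inside _generate_single_turn and threaded back through the return values of _generate; now _generate_and_score_completions builds them itself, so the generation helpers return only token IDs plus logprob/count bookkeeping. A minimal runnable sketch of the before/after call contract, using simplified stand-in functions rather than the trainers' real signatures:

    def toy_processor(images, text):
        # Stand-in for a Hugging Face processor output: a dict of model inputs.
        return {"input_ids": [[0]], "attention_mask": [[1]], "pixel_values": [[0.0]]}

    def generate_before(prompts, images, processor):
        # Old shape: generation built forward_kwargs and returned it.
        inputs = processor(images=images, text=prompts) if images is not None else {}
        forward_kwargs = {k: v for k, v in inputs.items() if k not in ["input_ids", "attention_mask"]}
        return [[1, 2]], [[3, 4]], forward_kwargs  # dummy prompt/completion ids

    def generate_after(prompts, images):
        # New shape: generation returns only token-level outputs.
        return [[1, 2]], [[3, 4]]  # dummy prompt/completion ids

    def generate_and_score_after(prompts, images, processor):
        prompt_ids, completion_ids = generate_after(prompts, images)
        # The scoring step now rebuilds forward_kwargs itself, right where the
        # extra tensors are consumed for the logprob forward pass.
        if images is not None:
            inputs = processor(images=images, text=prompts)
            forward_kwargs = {k: v for k, v in inputs.items() if k not in ["input_ids", "attention_mask"]}
        else:
            forward_kwargs = {}
        return prompt_ids, completion_ids, forward_kwargs

    print(generate_and_score_after(["hi"], [["<image>"]], toy_processor))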
diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py
@@ -18,7 +18,7 @@ from typing import Any, Callable
 import torch
 from accelerate.utils import gather_object
 
-from ...data_utils import is_conversational
+from ...data_utils import apply_chat_template, is_conversational
 from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
 from ...trainer.utils import nanmax, nanmin, nanstd, pad
 
@@ -80,13 +80,9 @@ class GFPOTrainer(_GRPOTrainer):
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        (
-            prompt_ids_list,
-            completion_ids_list,
-            num_items_in_batch,
-            sampling_per_token_logps_list,
-            forward_kwargs,
-        ) = self._generate(prompts, images)
+        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
+            prompts, images
+        )
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -112,6 +108,23 @@ class GFPOTrainer(_GRPOTrainer):
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
 
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
         # If token_type_ids are used, extend them with zeros for the completion part
         if "token_type_ids" in forward_kwargs:
             token_type_ids = forward_kwargs["token_type_ids"]
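The token_type_ids context above shows how prompt-side token types are extended with zeros to cover the completion tokens. A self-contained illustration of that tensor operation (toy shapes, not taken from the PR):

    import torch

    # Toy shapes: batch of 2, prompt length 3, completion length 4.
    completion_ids = torch.zeros(2, 4, dtype=torch.long)
    forward_kwargs = {"token_type_ids": torch.tensor([[0, 0, 1], [0, 1, 1]])}

    # Same pattern as the diff: pad token_type_ids with zeros so its length
    # matches the concatenated prompt+completion sequence (B, P+C).
    if "token_type_ids" in forward_kwargs:
        token_type_ids = forward_kwargs["token_type_ids"]
        forward_kwargs["token_type_ids"] = torch.cat(
            [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
        )

    print(forward_kwargs["token_type_ids"].shape)  # torch.Size([2, 7])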
@@ -119,11 +132,6 @@ class GFPOTrainer(_GRPOTrainer):
                 [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
             )
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-
-        num_images = [len(img_list) for img_list in images] if images is not None else None
-
         with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
@@ -1086,13 +1086,6 @@ class GRPOTrainer(BaseTrainer):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        if images is not None:
-            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
-            prompt_inputs = super()._prepare_inputs(prompt_inputs)
-            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-        else:
-            forward_kwargs = {}
-
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
             if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@@ -1307,13 +1300,13 @@ class GRPOTrainer(BaseTrainer):
             completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
             logprobs = None  # not used in this case
 
-        return prompt_ids, completion_ids, logprobs, forward_kwargs
+        return prompt_ids, completion_ids, logprobs
 
     def _generate(self, prompts: list[str], images: Optional[list]):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
+        prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts, images)
 
         # Get completion length per sequence, used for logging
         prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@@ -1345,7 +1338,7 @@ class GRPOTrainer(BaseTrainer):
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
 
-        return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
+        return prompt_ids, completion_ids, total_completion_tokens, logprobs
 
     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -1365,13 +1358,9 @@ class GRPOTrainer(BaseTrainer):
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        (
-            prompt_ids_list,
-            completion_ids_list,
-            num_items_in_batch,
-            sampling_per_token_logps_list,
-            forward_kwargs,
-        ) = self._generate(prompts, images)
+        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
+            prompts, images
+        )
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -1397,6 +1386,23 @@ class GRPOTrainer(BaseTrainer):
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
 
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
         # If token_type_ids are used, extend them with zeros for the completion part
         if "token_type_ids" in forward_kwargs:
             token_type_ids = forward_kwargs["token_type_ids"]
@@ -1404,11 +1410,6 @@ class GRPOTrainer(BaseTrainer):
                 [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
             )
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-
-        num_images = [len(img_list) for img_list in images] if images is not None else None
-
         with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py
@@ -1082,13 +1082,6 @@ class RLOOTrainer(BaseTrainer):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        if images is not None:
-            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
-            prompt_inputs = super()._prepare_inputs(prompt_inputs)
-            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-        else:
-            forward_kwargs = {}
-
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
             if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@@ -1292,13 +1285,13 @@ class RLOOTrainer(BaseTrainer):
             prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
             completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
 
-        return prompt_ids, completion_ids, forward_kwargs
+        return prompt_ids, completion_ids
 
     def _generate(self, prompts: list[str], images: Optional[list]):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images)
+        prompt_ids, completion_ids = self._generate_single_turn(prompts, images)
 
         # Get completion length per sequence, used for logging
         prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@@ -1331,7 +1324,7 @@ class RLOOTrainer(BaseTrainer):
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
 
-        return prompt_ids, completion_ids, forward_kwargs
+        return prompt_ids, completion_ids
 
     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -1351,7 +1344,7 @@ class RLOOTrainer(BaseTrainer):
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images)
+        prompt_ids_list, completion_ids_list = self._generate(prompts, images)
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -1372,6 +1365,23 @@ class RLOOTrainer(BaseTrainer):
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
 
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
         # If token_type_ids are used, extend them with zeros for the completion part
         if "token_type_ids" in forward_kwargs:
             token_type_ids = forward_kwargs["token_type_ids"]
@@ -1379,11 +1389,6 @@ class RLOOTrainer(BaseTrainer):
                 [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
             )
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-
-        num_images = [len(img_list) for img_list in images] if images is not None else None
-
         with torch.no_grad():
             # Compute the per-token log probabilities for the current model
             old_per_token_logps, _ = self._get_per_token_logps_and_entropies(