Merge branch 'main' into remove-fsdp1-support

Behrooz Azarkhalili, 2025-10-18 17:12:14 -07:00 (committed by GitHub)
4 changed files with 67 additions and 53 deletions

README.md

@@ -168,7 +168,7 @@ trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
     --output_dir Qwen2.5-0.5B-DPO
 ```
 
-Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/main/en/clis) or use `--help` for more details.
+Read more about CLI in the [relevant documentation section](https://huggingface.co/docs/trl/clis) or use `--help` for more details.
 
 ## Development
@@ -190,7 +190,7 @@ Example:
 from trl.experimental.new_trainer import NewTrainer
 ```
 
-Read more in the [Experimental docs](https://huggingface.co/docs/trl/main/en/experimental).
+Read more in the [Experimental docs](https://huggingface.co/docs/trl/experimental).
 
 ## Citation

trl/experimental/gfpo/gfpo_trainer.py

@@ -18,7 +18,7 @@ from typing import Any, Callable
 import torch
 from accelerate.utils import gather_object
 
-from ...data_utils import is_conversational
+from ...data_utils import apply_chat_template, is_conversational
 from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer
 from ...trainer.utils import nanmax, nanmin, nanstd, pad
@@ -80,13 +80,9 @@ class GFPOTrainer(_GRPOTrainer):
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        (
-            prompt_ids_list,
-            completion_ids_list,
-            num_items_in_batch,
-            sampling_per_token_logps_list,
-            forward_kwargs,
-        ) = self._generate(prompts, images)
+        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
+            prompts, images
+        )
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
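The "padded tensors" step in the context lines above relies on TRL's `pad` helper (imported in the first hunk of this file). As a minimal standalone sketch of the same idea, here is the equivalent using `torch.nn.utils.rnn.pad_sequence` as a stand-in for that helper; the toy ids and pad value 0 are illustrative, not TRL's configuration:

```python
import torch
from torch.nn.utils.rnn import pad_sequence  # stand-in for trl's pad helper

# Toy ragged token ids, one list per prompt
prompt_ids_list = [[5, 6, 7], [8, 9]]
prompt_ids = [torch.tensor(ids) for ids in prompt_ids_list]

# Pad to a rectangular (B, P_max) tensor; 0 is the pad id in this toy example
prompt_ids = pad_sequence(prompt_ids, batch_first=True, padding_value=0)
prompt_mask = (prompt_ids != 0).long()  # 1 for real tokens, 0 for padding
print(prompt_ids.shape)  # torch.Size([2, 3])
```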
@@ -112,6 +108,23 @@ class GFPOTrainer(_GRPOTrainer):
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
         # If token_type_ids are used, extend them with zeros for the completion part
         if "token_type_ids" in forward_kwargs:
             token_type_ids = forward_kwargs["token_type_ids"]
@@ -119,11 +132,6 @@ class GFPOTrainer(_GRPOTrainer):
                 [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
             )
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-        num_images = [len(img_list) for img_list in images] if images is not None else None
-
         with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
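The `token_type_ids` handling kept as context above extends the prompt-side ids with zeros so they cover the full `(B, P+C)` sequence. A self-contained illustration of the `new_zeros` + `cat` pattern on toy tensors (not TRL code):

```python
import torch

# Toy prompt-side token_type_ids: batch of 1, prompt length 4
token_type_ids = torch.tensor([[0, 0, 1, 1]])
# Pretend the model generated 3 completion tokens
completion_ids = torch.tensor([[7, 8, 9]])

# Extend with zeros for the completion part; new_zeros keeps dtype/device
extended = torch.cat(
    [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
)
print(extended)  # tensor([[0, 0, 1, 1, 0, 0, 0]]) -> shape (B, P+C)
```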

trl/trainer/grpo_trainer.py

@@ -1046,13 +1046,6 @@ class GRPOTrainer(BaseTrainer):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        if images is not None:
-            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
-            prompt_inputs = super()._prepare_inputs(prompt_inputs)
-            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-        else:
-            forward_kwargs = {}
-
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
             if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@@ -1267,13 +1260,13 @@
             completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
             logprobs = None  # not used in this case
 
-        return prompt_ids, completion_ids, logprobs, forward_kwargs
+        return prompt_ids, completion_ids, logprobs
 
     def _generate(self, prompts: list[str], images: Optional[list]):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
+        prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts, images)
 
         # Get completion length per sequence, used for logging
         prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@@ -1305,7 +1298,7 @@
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
 
-        return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
+        return prompt_ids, completion_ids, total_completion_tokens, logprobs
 
     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -1325,13 +1318,9 @@ class GRPOTrainer(BaseTrainer):
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        (
-            prompt_ids_list,
-            completion_ids_list,
-            num_items_in_batch,
-            sampling_per_token_logps_list,
-            forward_kwargs,
-        ) = self._generate(prompts, images)
+        prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate(
+            prompts, images
+        )
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -1357,6 +1346,23 @@ class GRPOTrainer(BaseTrainer):
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
         # If token_type_ids are used, extend them with zeros for the completion part
         if "token_type_ids" in forward_kwargs:
             token_type_ids = forward_kwargs["token_type_ids"]
@@ -1364,11 +1370,6 @@ class GRPOTrainer(BaseTrainer):
                 [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
             )
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-        num_images = [len(img_list) for img_list in images] if images is not None else None
-
        with torch.no_grad():
             # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
             # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
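Taken together, the GRPO hunks move the multimodal `forward_kwargs` construction out of `_generate_single_turn` (and out of `_generate`'s return values) into `_generate_and_score_completions`. A self-contained sketch of the filtering step itself, with a stand-in dict in place of a real processor output (the keys below are illustrative; a vision-language processor typically returns `input_ids`/`attention_mask` plus extras such as `pixel_values`):

```python
import torch

# Stand-in for processing_class(images=..., text=..., return_tensors="pt")
prompt_inputs = {
    "input_ids": torch.tensor([[101, 7592, 102]]),
    "attention_mask": torch.ones(1, 3, dtype=torch.long),
    "pixel_values": torch.randn(1, 3, 224, 224),
}

# Drop ids/mask: the trainer rebuilds those from the padded prompt and
# completion tensors, and forwards only the extra model inputs.
forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
print(sorted(forward_kwargs))  # ['pixel_values']
```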

trl/trainer/rloo_trainer.py

@@ -1042,13 +1042,6 @@ class RLOOTrainer(BaseTrainer):
             maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
         ]
 
-        if images is not None:
-            prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs)
-            prompt_inputs = super()._prepare_inputs(prompt_inputs)
-            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
-        else:
-            forward_kwargs = {}
-
         # Generate completions using either vLLM or regular generation
         if self.use_vllm:
             if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode:
@@ -1252,13 +1245,13 @@
             prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())]
             completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())]
 
-        return prompt_ids, completion_ids, forward_kwargs
+        return prompt_ids, completion_ids
 
     def _generate(self, prompts: list[str], images: Optional[list]):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images)
+        prompt_ids, completion_ids = self._generate_single_turn(prompts, images)
 
         # Get completion length per sequence, used for logging
         prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device)
@@ -1291,7 +1284,7 @@
         self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
         self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
 
-        return prompt_ids, completion_ids, forward_kwargs
+        return prompt_ids, completion_ids
 
     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@@ -1311,7 +1304,7 @@
         if images is not None and all(img_list == [] for img_list in images):
             images = None
 
-        prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images)
+        prompt_ids_list, completion_ids_list = self._generate(prompts, images)
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@@ -1332,6 +1325,23 @@ class RLOOTrainer(BaseTrainer):
         # Concatenate prompt_mask with completion_mask for logit computation
         prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)  # (B, P+C)
         attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)  # (B, P+C)
+        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
+        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
+        num_images = [len(img_list) for img_list in images] if images is not None else None
+
+        # Get forward_kwargs for models with multimodal inputs
+        if images is not None:
+            prompts_text = [
+                apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts
+            ]
+            prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt")
+            prompt_inputs = super()._prepare_inputs(prompt_inputs)
+            forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]}
+        else:
+            forward_kwargs = {}
+
         # If token_type_ids are used, extend them with zeros for the completion part
         if "token_type_ids" in forward_kwargs:
             token_type_ids = forward_kwargs["token_type_ids"]
@@ -1339,11 +1349,6 @@ class RLOOTrainer(BaseTrainer):
                 [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1
             )
 
-        logits_to_keep = completion_ids.size(1)  # we only need to compute the logits for the completion tokens
-        batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size
-        num_images = [len(img_list) for img_list in images] if images is not None else None
-
         with torch.no_grad():
             # Compute the per-token log probabilities for the current model
             old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
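For context on the `logits_to_keep` comment and the `_get_per_token_logps_and_entropies` call that closes this hunk, here is a hedged sketch (not TRL's exact implementation) of scoring only the completion tokens of a `(B, P+C)` sequence; all shapes and the random tensors are illustrative:

```python
import torch
import torch.nn.functional as F

B, P, C, V = 2, 4, 3, 11  # batch, prompt len, completion len, vocab size
logits = torch.randn(B, P + C, V)          # model logits over prompt+completion
completion_ids = torch.randint(V, (B, C))  # sampled completion token ids

logits_to_keep = completion_ids.size(1)
# The logit at position t predicts token t+1, so shift by one before
# slicing out the last `logits_to_keep` positions.
shifted = logits[:, -logits_to_keep - 1 : -1, :]  # (B, C, V)
per_token_logps = torch.gather(
    F.log_softmax(shifted, dim=-1), dim=2, index=completion_ids.unsqueeze(-1)
).squeeze(-1)                                     # (B, C)
```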