Merge branch 'refactor_generate' into tool-call

This commit is contained in:
Quentin Gallouédec
2025-09-26 11:38:58 -06:00
committed by GitHub

View File

@ -1355,7 +1355,7 @@ class GRPOTrainer(BaseTrainer):
device = self.accelerator.device
mode = "train" if self.model.training else "eval"
prompt_ids, completion_ids, completion_logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
completion_contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
completions = [[{"role": "assistant", "content": content}] for content in completion_contents]
tool_calls = [extract_tool_calls(completion) for completion in completion_contents]
@ -1377,7 +1377,7 @@ class GRPOTrainer(BaseTrainer):
agg_prompt_lengths = self.accelerator.gather(prompt_lengths)
agg_completion_lengths = self.accelerator.gather(completion_lengths)
total_prompt_tokens = agg_prompt_lengths.sum()
total_completion_tokens = agg_completion_lengths.sum()
total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss
# Log the metrics
if mode == "train":
@ -1401,7 +1401,7 @@ class GRPOTrainer(BaseTrainer):
self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
return prompt_ids, completion_ids, completion_logprobs, forward_kwargs
return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
def _generate_and_score_completions(
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
@ -1418,13 +1418,13 @@ class GRPOTrainer(BaseTrainer):
else:
images = None
(prompt_ids_list, completion_ids_list, sampling_per_token_logps_list, forward_kwargs) = self._generate(
prompts, images
)
# Identify truncated sequences (not ending with EOS or PAD) before any padding is applied
eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id]
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
(
prompt_ids_list,
completion_ids_list,
num_items_in_batch,
sampling_per_token_logps_list,
forward_kwargs,
) = self._generate(prompts, images)
# Convert lists of token IDs to padded tensors
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
@ -1441,10 +1441,10 @@ class GRPOTrainer(BaseTrainer):
else:
sampling_per_token_logps = None
num_items_in_batch = completion_mask.sum()
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
if self.mask_truncated_completions:
eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id]
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
# Concatenate prompt_mask with completion_mask for logit computation