mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 10:03:51 +08:00
Merge branch 'refactor_generate' into tool-call
This commit is contained in:
@ -1355,7 +1355,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
device = self.accelerator.device
|
||||
mode = "train" if self.model.training else "eval"
|
||||
|
||||
prompt_ids, completion_ids, completion_logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
|
||||
prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
|
||||
completion_contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
|
||||
completions = [[{"role": "assistant", "content": content}] for content in completion_contents]
|
||||
tool_calls = [extract_tool_calls(completion) for completion in completion_contents]
|
||||
@ -1377,7 +1377,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
agg_prompt_lengths = self.accelerator.gather(prompt_lengths)
|
||||
agg_completion_lengths = self.accelerator.gather(completion_lengths)
|
||||
total_prompt_tokens = agg_prompt_lengths.sum()
|
||||
total_completion_tokens = agg_completion_lengths.sum()
|
||||
total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss
|
||||
|
||||
# Log the metrics
|
||||
if mode == "train":
|
||||
@ -1401,7 +1401,7 @@ class GRPOTrainer(BaseTrainer):
|
||||
self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
|
||||
self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
|
||||
|
||||
return prompt_ids, completion_ids, completion_logprobs, forward_kwargs
|
||||
return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
|
||||
|
||||
def _generate_and_score_completions(
|
||||
self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
|
||||
@ -1418,13 +1418,13 @@ class GRPOTrainer(BaseTrainer):
|
||||
else:
|
||||
images = None
|
||||
|
||||
(prompt_ids_list, completion_ids_list, sampling_per_token_logps_list, forward_kwargs) = self._generate(
|
||||
prompts, images
|
||||
)
|
||||
|
||||
# Identify truncated sequences (not ending with EOS or PAD) before any padding is applied
|
||||
eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id]
|
||||
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
|
||||
(
|
||||
prompt_ids_list,
|
||||
completion_ids_list,
|
||||
num_items_in_batch,
|
||||
sampling_per_token_logps_list,
|
||||
forward_kwargs,
|
||||
) = self._generate(prompts, images)
|
||||
|
||||
# Convert lists of token IDs to padded tensors
|
||||
prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
|
||||
@ -1441,10 +1441,10 @@ class GRPOTrainer(BaseTrainer):
|
||||
else:
|
||||
sampling_per_token_logps = None
|
||||
|
||||
num_items_in_batch = completion_mask.sum()
|
||||
|
||||
# If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
|
||||
if self.mask_truncated_completions:
|
||||
eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id]
|
||||
is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
|
||||
completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
|
||||
|
||||
# Concatenate prompt_mask with completion_mask for logit computation
|
||||
|
Reference in New Issue
Block a user