Merge branch 'refactor_generate' into tool-call
@@ -1355,7 +1355,7 @@ class GRPOTrainer(BaseTrainer):
         device = self.accelerator.device
         mode = "train" if self.model.training else "eval"
 
-        prompt_ids, completion_ids, completion_logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
+        prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images)
         completion_contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
         completions = [[{"role": "assistant", "content": content}] for content in completion_contents]
         tool_calls = [extract_tool_calls(completion) for completion in completion_contents]
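
The first hunk only renames the returned log-probabilities; the tool-call extraction is untouched context. For orientation, here is a minimal sketch of what a helper like extract_tool_calls could look like, assuming tool calls are emitted as JSON inside <tool_call>...</tool_call> tags. The tag format, regex, and return shape are illustrative assumptions, not TRL's actual implementation:

    import json
    import re

    # Assumed tag format; the real parser in this branch may differ.
    TOOL_CALL_PATTERN = re.compile(r"<tool_call>(.*?)</tool_call>", re.DOTALL)

    def extract_tool_calls(text: str) -> list[dict]:
        """Parse JSON tool-call payloads out of a decoded completion string."""
        calls = []
        for payload in TOOL_CALL_PATTERN.findall(text):
            try:
                calls.append(json.loads(payload))
            except json.JSONDecodeError:
                continue  # skip malformed payloads rather than failing the batch
        return calls
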
@@ -1377,7 +1377,7 @@ class GRPOTrainer(BaseTrainer):
         agg_prompt_lengths = self.accelerator.gather(prompt_lengths)
         agg_completion_lengths = self.accelerator.gather(completion_lengths)
         total_prompt_tokens = agg_prompt_lengths.sum()
-        total_completion_tokens = agg_completion_lengths.sum()
+        total_completion_tokens = agg_completion_lengths.sum()  # = num_items_in_batch, required for the DAPO loss
 
         # Log the metrics
         if mode == "train":
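
The new inline comment pins down why total_completion_tokens is surfaced: DAPO's token-level objective normalizes the summed per-token loss by the total number of completion tokens in the global batch (gathered across all processes), rather than averaging per sequence. A hedged sketch of that normalization, with illustrative variable names:

    import torch

    def dapo_style_loss(per_token_loss: torch.Tensor,
                        completion_mask: torch.Tensor,
                        num_items_in_batch: torch.Tensor) -> torch.Tensor:
        # per_token_loss, completion_mask: (batch, seq_len); mask is 0/1.
        # num_items_in_batch: completion-token count summed over *all* processes,
        # i.e. the total_completion_tokens value computed above.
        return (per_token_loss * completion_mask).sum() / num_items_in_batch

Normalizing by the global token count weights every completion by its actual token contribution, which is why a locally computed completion_mask.sum() (removed in a later hunk) is no longer sufficient.
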
@@ -1401,7 +1401,7 @@ class GRPOTrainer(BaseTrainer):
             self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item())
             self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item())
 
-        return prompt_ids, completion_ids, completion_logprobs, forward_kwargs
+        return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs
 
     def _generate_and_score_completions(
         self, inputs: list[dict[str, Union[torch.Tensor, Any]]]
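
The metric lines above imply a term_completion_lengths tensor holding the lengths of completions that actually terminated. A speculative sketch of how such a tensor might be derived, with an empty-case guard so .min()/.max() cannot fail (the names, data, and guard are assumptions):

    import torch

    completion_lengths = torch.tensor([12, 48, 7, 48])           # illustrative
    terminated_with_eos = torch.tensor([True, False, True, True])

    term_completion_lengths = completion_lengths[terminated_with_eos]
    if len(term_completion_lengths) == 0:
        # Avoid min()/max() on an empty tensor when nothing terminated.
        term_completion_lengths = torch.zeros(1)

    print(term_completion_lengths.float().min().item())  # 7.0
    print(term_completion_lengths.float().max().item())  # 48.0
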
@@ -1418,13 +1418,13 @@ class GRPOTrainer(BaseTrainer):
         else:
             images = None
 
-        (prompt_ids_list, completion_ids_list, sampling_per_token_logps_list, forward_kwargs) = self._generate(
-            prompts, images
-        )
-
-        # Identify truncated sequences (not ending with EOS or PAD) before any padding is applied
-        eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id]
-        is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
+        (
+            prompt_ids_list,
+            completion_ids_list,
+            num_items_in_batch,
+            sampling_per_token_logps_list,
+            forward_kwargs,
+        ) = self._generate(prompts, images)
 
         # Convert lists of token IDs to padded tensors
         prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list]
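
After the reshaped _generate call, the ragged per-sample ID lists are converted into rectangular tensors. A generic sketch of right-padding plus mask construction, not necessarily the padding utility this codebase uses:

    import torch

    def pad_right(sequences: list[torch.Tensor], pad_value: int):
        """Right-pad 1-D tensors to a common length; return (padded, 0/1 mask)."""
        max_len = max(seq.size(0) for seq in sequences)
        padded = torch.full((len(sequences), max_len), pad_value, dtype=torch.long)
        mask = torch.zeros((len(sequences), max_len), dtype=torch.int)
        for i, seq in enumerate(sequences):
            padded[i, : seq.size(0)] = seq
            mask[i, : seq.size(0)] = 1
        return padded, mask

    completion_ids, completion_mask = pad_right(
        [torch.tensor(ids) for ids in [[5, 6, 7], [5, 6]]], pad_value=0
    )
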
@@ -1441,10 +1441,10 @@ class GRPOTrainer(BaseTrainer):
         else:
             sampling_per_token_logps = None
 
-        num_items_in_batch = completion_mask.sum()
-
         # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask
         if self.mask_truncated_completions:
+            eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id]
+            is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device)
             completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
 
         # Concatenate prompt_mask with completion_mask for logit computation
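
To make the relocated truncation check concrete, here is a small self-contained example of the two added lines plus the existing mask update; the token IDs are made up (2 = EOS, 0 = PAD):

    import torch

    eos_and_pad = [2, 0]  # [eos_token_id, pad_token_id]

    # The middle completion hit the length limit: it ends with an ordinary token.
    completion_ids_list = [[5, 9, 2], [5, 9, 9], [5, 2, 0]]
    is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list])

    completion_mask = torch.tensor([[1, 1, 1], [1, 1, 1], [1, 1, 0]])
    completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int()
    print(completion_mask)
    # tensor([[1, 1, 1],
    #         [0, 0, 0],   <- truncated completion contributes no loss tokens
    #         [1, 1, 0]])

Scoping the EOS/PAD scan inside if self.mask_truncated_completions: means it is only paid for when the feature is enabled.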