mirror of https://github.com/huggingface/trl.git
Add support for token_type_ids in DPOTrainer (#4285)

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
@@ -1422,6 +1422,7 @@ class TestDPOVisionTrainer(TrlTestCase):
             # ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",),
             ("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
             ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
+            ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
         ]
     )
     def test_vdpo_trainer(self, model_id):
@@ -177,6 +177,9 @@ class DataCollatorForPreference(DataCollatorMixin):
         if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
             output["ref_chosen_logps"] = ref_chosen_logps
             output["ref_rejected_logps"] = ref_rejected_logps
+        if "token_type_ids" in examples[0]:
+            token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
+            output["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left")

         return output
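For orientation (not part of the diff): a minimal sketch of how rows carrying token_type_ids would flow through DataCollatorForPreference. The key names come from the hunk above; the values are invented, and the import path and pad_token_id argument are assumptions about TRL's collator.

    from trl.trainer.dpo_trainer import DataCollatorForPreference

    # Two invented rows; "token_type_ids" is aligned with "prompt_input_ids"
    # (for Gemma3-style processors, 1 typically marks image-placeholder positions).
    examples = [
        {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5],
         "rejected_input_ids": [6], "token_type_ids": [0, 1, 1]},
        {"prompt_input_ids": [7, 8], "chosen_input_ids": [9],
         "rejected_input_ids": [10, 11], "token_type_ids": [0, 0]},
    ]

    collator = DataCollatorForPreference(pad_token_id=0)
    batch = collator(examples)
    # Per the added lines, token_type_ids is left-padded with 0 to the longest
    # prompt in the batch, mirroring the padding of prompt_input_ids.
    print(batch["token_type_ids"].shape)  # torch.Size([2, 3])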
@@ -790,6 +793,8 @@ class DPOTrainer(BaseTrainer):
             output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
         if "image_sizes" in processed_features:
             output["image_sizes"] = processed_features["image_sizes"][0]
+        if "token_type_ids" in processed_features:
+            output["token_type_ids"] = processed_features["token_type_ids"][0]

         return output
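A side note on the [0] indexing, which matches the neighbouring keys: the processor returns batched tensors of shape (1, seq_len), and the per-example row is stored without that batch dimension. A tiny illustration with invented values:

    import torch

    processed_features = {"token_type_ids": torch.tensor([[0, 0, 1, 1, 1, 0]])}  # shape (1, 6)
    row = processed_features["token_type_ids"][0]  # shape (6,), as stored in output
    print(row.shape)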
@@ -804,6 +809,7 @@ class DPOTrainer(BaseTrainer):
             "chosen_input_ids",
             "rejected_input_ids",
             "image_sizes",
+            "token_type_ids",
             "ref_chosen_logps",
             "ref_rejected_logps",
         ]
@@ -991,6 +997,8 @@ class DPOTrainer(BaseTrainer):
             )
         if "image_sizes" in batch:
             output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
+        if "token_type_ids" in batch:
+            output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))

         # Concatenate the chosen and rejected completions
         max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
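As with image_sizes just above, the prompt-side token_type_ids is duplicated along the batch dimension so that a single forward pass can score both the chosen and the rejected completion against the same prompt. A standalone illustration with invented shapes:

    import torch

    batch = {"token_type_ids": torch.tensor([[0, 1, 1],
                                             [0, 0, 1]])}  # (num_examples, prompt_len)

    # Rows 0..B-1 pair with the chosen completions, rows B..2B-1 with the rejected ones.
    doubled = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
    print(doubled.shape)  # torch.Size([4, 3])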
@@ -1516,6 +1524,9 @@ class DPOTrainer(BaseTrainer):
         # Concatenate the prompt and completion inputs
         input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
         attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
+        if "token_type_ids" in concatenated_batch:
+            prompt_token_type_ids = concatenated_batch["token_type_ids"]
+            token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
         # Mask the prompt but not the completion for the loss
         loss_mask = torch.cat(
             (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
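The prompt-only token_type_ids is right-padded with 0 out to the full prompt+completion length, since completion tokens are ordinary text tokens. A sketch assuming pad_to_length is importable from trl.trainer.utils with a (tensor, length, pad_value) signature:

    import torch
    from trl.trainer.utils import pad_to_length

    prompt_token_type_ids = torch.tensor([[0, 1, 1, 0]])  # prompt only, length 4
    input_ids = torch.zeros(1, 7, dtype=torch.long)       # prompt + completion, length 7

    token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
    print(token_type_ids)  # tensor([[0, 1, 1, 0, 0, 0, 0]])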
@@ -1528,7 +1539,12 @@ class DPOTrainer(BaseTrainer):
                 # Flush left to reduce the memory usage
                 # [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
                 #  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
-                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                else:
+                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
                 attention_mask = attention_mask[:, : self.max_length]
                 input_ids = input_ids[:, : self.max_length]
                 loss_mask = loss_mask[:, : self.max_length]
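token_type_ids has to be flushed together with attention_mask, input_ids and loss_mask; flushing it separately (or not at all) would break the column alignment. A sketch assuming flush_left is importable from trl.trainer.utils, reproducing the pattern shown in the comment above:

    import torch
    from trl.trainer.utils import flush_left

    attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1],
                                   [0, 1, 1, 1, 0, 0]])
    input_ids      = torch.tensor([[0, 0, 5, 6, 7, 8],
                                   [0, 9, 10, 11, 0, 0]])
    token_type_ids = torch.tensor([[0, 0, 1, 1, 0, 0],
                                   [0, 1, 1, 0, 0, 0]])

    # Every tensor is shifted by the same per-row offset and trimmed together.
    attention_mask, input_ids, token_type_ids = flush_left(attention_mask, input_ids, token_type_ids)
    print(attention_mask)
    # tensor([[1, 1, 1, 1],
    #         [1, 1, 1, 0]])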
@@ -1536,11 +1552,22 @@ class DPOTrainer(BaseTrainer):
                 # Flush right before truncating left, then flush left
                 # [[0, 0, x, x, x, x],  ->  [[0, 0, x, x],
                 #  [0, x, x, x, 0, 0]]       [0, x, x, x]]
-                attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_right(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                    token_type_ids = token_type_ids[:, -self.max_length :]
+                else:
+                    attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
                 input_ids = input_ids[:, -self.max_length :]
                 attention_mask = attention_mask[:, -self.max_length :]
                 loss_mask = loss_mask[:, -self.max_length :]
-                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                else:
+                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
             else:
                 raise ValueError(
                     f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
@@ -1550,7 +1577,15 @@ class DPOTrainer(BaseTrainer):
             # Flush left to reduce the memory usage
             # [[0, 0, x, x, x, x],  ->  [[x, x, x, x],
             #  [0, x, x, x, 0, 0]]       [x, x, x, 0]]
-            attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+            else:
+                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+
+        if "token_type_ids" in concatenated_batch:
+            model_kwargs["token_type_ids"] = token_type_ids

         if self.use_logits_to_keep:
             # Compute logits_to_keep based on loss_mask pattern: