Add support for token_type_ids in DPOTrainer (#4285)

Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Co-authored-by: Quentin Gallouédec <gallouedec.quentin@gmail.com>
Authored by Alexander Weers on 2025-10-16 01:33:35 +02:00, committed by GitHub
parent aa25c2697c
commit 26b7c2507e
2 changed files with 40 additions and 4 deletions

tests/test_dpo_trainer.py

@@ -1422,6 +1422,7 @@ class TestDPOVisionTrainer(TrlTestCase):
             # ("trl-internal-testing/tiny-PaliGemmaForConditionalGeneration",),
             ("trl-internal-testing/tiny-LlavaForConditionalGeneration",),
             ("trl-internal-testing/tiny-LlavaNextForConditionalGeneration",),
+            ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",),
         ]
     )
     def test_vdpo_trainer(self, model_id):

trl/trainer/dpo_trainer.py

@@ -177,6 +177,9 @@ class DataCollatorForPreference(DataCollatorMixin):
         if "ref_chosen_logps" in examples[0] and "ref_rejected_logps" in examples[0]:
             output["ref_chosen_logps"] = ref_chosen_logps
             output["ref_rejected_logps"] = ref_rejected_logps
+        if "token_type_ids" in examples[0]:
+            token_type_ids = [torch.tensor(example["token_type_ids"]) for example in examples]
+            output["token_type_ids"] = pad(token_type_ids, padding_value=0, padding_side="left")
         return output
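
The collator left-pads `token_type_ids` with 0 so the padding stays aligned with the left-padded prompt `input_ids`. A minimal self-contained sketch of that behavior, using plain `torch` in place of trl's internal `pad` helper:

```python
import torch

# Two preference examples whose token_type_ids have different lengths
# (e.g. 1 could mark image tokens, 0 text tokens).
examples = [
    {"token_type_ids": [1, 1, 0, 0, 0]},
    {"token_type_ids": [0, 0, 0]},
]

tensors = [torch.tensor(ex["token_type_ids"]) for ex in examples]
max_len = max(t.size(0) for t in tensors)

# Left-pad with 0, mirroring pad(..., padding_value=0, padding_side="left")
padded = torch.stack(
    [torch.cat((torch.zeros(max_len - t.size(0), dtype=t.dtype), t)) for t in tensors]
)
print(padded)
# tensor([[1, 1, 0, 0, 0],
#         [0, 0, 0, 0, 0]])
```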
@@ -790,6 +793,8 @@ class DPOTrainer(BaseTrainer):
             output["pixel_attention_mask"] = processed_features["pixel_attention_mask"][0]
         if "image_sizes" in processed_features:
             output["image_sizes"] = processed_features["image_sizes"][0]
+        if "token_type_ids" in processed_features:
+            output["token_type_ids"] = processed_features["token_type_ids"][0]
         return output
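
`process_row` simply forwards whatever the processor emits; the `[0]` drops the batch dimension, exactly as for the other fields. A hedged sketch of the shape involved (the token ids and layout are illustrative, not actual Gemma3 values):

```python
# What processed_features might look like for one row of a Gemma3-style VLM:
# the processor returns batched (1, seq_len) fields, and token_type_ids
# flags image-token positions with 1, text positions with 0.
processed_features = {
    "input_ids": [[2, 255999, 262144, 262144, 108, 3689]],  # illustrative ids
    "token_type_ids": [[0, 0, 1, 1, 0, 0]],
}
output = {"prompt_input_ids": processed_features["input_ids"][0]}
if "token_type_ids" in processed_features:
    output["token_type_ids"] = processed_features["token_type_ids"][0]  # drop batch dim
```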
@@ -804,6 +809,7 @@ class DPOTrainer(BaseTrainer):
             "chosen_input_ids",
             "rejected_input_ids",
             "image_sizes",
+            "token_type_ids",
             "ref_chosen_logps",
             "ref_rejected_logps",
         ]
@@ -991,6 +997,8 @@ class DPOTrainer(BaseTrainer):
             )
         if "image_sizes" in batch:
             output["image_sizes"] = torch.cat([batch["image_sizes"], batch["image_sizes"]], dim=0)
+        if "token_type_ids" in batch:
+            output["token_type_ids"] = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
 
         # Concatenate the chosen and rejected completions
         max_completion_length = max(batch["chosen_input_ids"].shape[1], batch["rejected_input_ids"].shape[1])
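
In `concatenated_inputs`, prompt-side tensors are duplicated along the batch dimension so the chosen and rejected halves of the doubled batch share the same prompts; `token_type_ids` follows the same pattern as `image_sizes`. A small sketch:

```python
import torch

batch = {"token_type_ids": torch.tensor([[1, 1, 0], [0, 0, 0]])}  # (N, prompt_len)
doubled = torch.cat((batch["token_type_ids"], batch["token_type_ids"]))
print(doubled.shape)  # torch.Size([4, 3]): row i and row i + N share a prompt
```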
@@ -1516,6 +1524,9 @@ class DPOTrainer(BaseTrainer):
         # Concatenate the prompt and completion inputs
         input_ids = torch.cat((prompt_input_ids, completion_input_ids), dim=1)
         attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
+        if "token_type_ids" in concatenated_batch:
+            prompt_token_type_ids = concatenated_batch["token_type_ids"]
+            token_type_ids = pad_to_length(prompt_token_type_ids, input_ids.shape[1], 0)
         # Mask the prompt but not the completion for the loss
         loss_mask = torch.cat(
             (torch.zeros_like(prompt_attention_mask), completion_attention_mask),
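
`token_type_ids` only covers the prompt, so it is padded with 0 up to the full prompt-plus-completion width (completion tokens are text, hence type 0). A minimal stand-in for trl's `pad_to_length`, assuming its pad-on-the-right semantics:

```python
import torch
import torch.nn.functional as F

def pad_to_length_sketch(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    # Right-pad the last dimension up to `length`; leave longer tensors alone.
    if tensor.size(-1) >= length:
        return tensor
    return F.pad(tensor, (0, length - tensor.size(-1)), value=pad_value)

prompt_token_type_ids = torch.tensor([[1, 1, 0]])  # prompt only
token_type_ids = pad_to_length_sketch(prompt_token_type_ids, 5, 0)
print(token_type_ids)  # tensor([[1, 1, 0, 0, 0]]): completion slots are type 0
```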
@@ -1528,7 +1539,12 @@ class DPOTrainer(BaseTrainer):
                 # Flush left to reduce the memory usage
                 # [[0, 0, x, x, x, x], -> [[x, x, x, x],
                 #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
-                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                else:
+                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
                 attention_mask = attention_mask[:, : self.max_length]
                 input_ids = input_ids[:, : self.max_length]
                 loss_mask = loss_mask[:, : self.max_length]
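
`flush_left` accepts any number of trailing tensors and shifts them in lockstep with the mask, which is why `token_type_ids` can simply be appended to the call. A loop-based sketch of its semantics (trl's real helper is vectorized):

```python
import torch

def flush_left_sketch(mask: torch.Tensor, *tensors: torch.Tensor):
    # Shift each row so its first attended token sits at column 0,
    # then drop columns that are padding in every row.
    out = [mask.clone()] + [t.clone() for t in tensors]
    for i in range(mask.size(0)):
        shift = int(mask[i].argmax())  # index of the first 1 in the mask
        for t in out:
            t[i] = torch.roll(t[i], -shift)
    keep = int(out[0].sum(dim=1).max())  # longest row after shifting
    return tuple(t[:, :keep] for t in out)

attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1], [0, 1, 1, 1, 0, 0]])
input_ids = torch.tensor([[0, 0, 5, 6, 7, 8], [0, 9, 10, 11, 0, 0]])
token_type_ids = torch.tensor([[0, 0, 1, 1, 0, 0], [0, 1, 0, 0, 0, 0]])
attention_mask, input_ids, token_type_ids = flush_left_sketch(attention_mask, input_ids, token_type_ids)
print(input_ids)  # tensor([[ 5,  6,  7,  8],
                  #         [ 9, 10, 11,  0]])
```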
@@ -1536,11 +1552,22 @@ class DPOTrainer(BaseTrainer):
                 # Flush right before truncating left, then flush left
                 # [[0, 0, x, x, x, x], -> [[0, 0, x, x],
                 #  [0, x, x, x, 0, 0]]     [0, x, x, x]]
-                attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_right(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                    token_type_ids = token_type_ids[:, -self.max_length :]
+                else:
+                    attention_mask, input_ids, loss_mask = flush_right(attention_mask, input_ids, loss_mask)
                 input_ids = input_ids[:, -self.max_length :]
                 attention_mask = attention_mask[:, -self.max_length :]
                 loss_mask = loss_mask[:, -self.max_length :]
-                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+                if "token_type_ids" in concatenated_batch:
+                    attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                        attention_mask, input_ids, loss_mask, token_type_ids
+                    )
+                else:
+                    attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
             else:
                 raise ValueError(
                     f"Unknown truncation mode: '{self.truncation_mode}'. Should be one of ['keep_end', "
@@ -1550,7 +1577,15 @@ class DPOTrainer(BaseTrainer):
             # Flush left to reduce the memory usage
             # [[0, 0, x, x, x, x], -> [[x, x, x, x],
             #  [0, x, x, x, 0, 0]]     [x, x, x, 0]]
-            attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+            if "token_type_ids" in concatenated_batch:
+                attention_mask, input_ids, loss_mask, token_type_ids = flush_left(
+                    attention_mask, input_ids, loss_mask, token_type_ids
+                )
+            else:
+                attention_mask, input_ids, loss_mask = flush_left(attention_mask, input_ids, loss_mask)
+
+        if "token_type_ids" in concatenated_batch:
+            model_kwargs["token_type_ids"] = token_type_ids
 
         if self.use_logits_to_keep:
             # Compute logits_to_keep based on loss_mask pattern:
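
With `token_type_ids` now in `model_kwargs`, it reaches the model's forward pass alongside `input_ids` and `attention_mask`. A hedged end-to-end sketch (the tiny model id comes from the test matrix above; the dataset id and config values are illustrative, mirroring how trl's vision DPO tests are typically set up):

```python
from datasets import load_dataset
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import DPOConfig, DPOTrainer

model_id = "trl-internal-testing/tiny-Gemma3ForConditionalGeneration"
model = AutoModelForImageTextToText.from_pretrained(model_id)
processor = AutoProcessor.from_pretrained(model_id)

# Illustrative preference dataset with images (id assumed, per trl's test suite).
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_preference", split="train")

trainer = DPOTrainer(
    model=model,
    args=DPOConfig(output_dir="dpo-gemma3", max_length=512),
    processing_class=processor,  # the processor's token_type_ids now flow through
    train_dataset=dataset,
)
trainer.train()
```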