🏷️ Account for token_type_ids in DataCollatorForVisionLanguageModeling (#4190)

Author: Quentin Gallouédec
Date: 2025-10-08 09:34:48 -06:00
Committed by: GitHub
Parent: 824ff8c73e
Commit: d1d0407d3c
2 changed files with 46 additions and 1 deletion
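Background for the change: most VLM processors return only input_ids, attention_mask, and pixel_values, but Gemma 3's processor additionally returns a token_type_ids tensor that marks which positions hold image tokens. A minimal sketch of inspecting that extra key (the checkpoint is the tiny test model used in the test below; the prompt string and exact output keys are assumptions for illustration, not taken from this commit):

# Sketch: inspect the extra key that Gemma-style processors emit.
# Prompt format and output keys are assumptions, not verified here.
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Gemma3ForConditionalGeneration")
image = Image.new("RGB", (64, 64))
batch = processor(text="<start_of_image> What is in this image?", images=image, return_tensors="pt")
print(batch.keys())             # expected to include "token_type_ids"
print(batch["token_type_ids"])  # 1 on image-token positions, 0 on text (Gemma convention)

If the collator drops this key, the model can no longer tell image positions from text positions, which is what the test and collator changes below guard against.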


@@ -1441,6 +1441,38 @@ class TestSFTTrainer(TrlTestCase):
             new_param = trainer.model.get_parameter(n)
             assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
 
+    # Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly handled
+    # in the collator.
+    @require_vision
+    def test_train_vlm_prompt_completion_gemma(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(
+            output_dir=self.tmp_dir,
+            max_length=None,  # For VLMs, truncating can remove image tokens, leading to errors
+            report_to="none",
+        )
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
+
     # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
     # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
     @pytest.mark.slow
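The test above exercises the fix through a full training run; a more direct check would be to call the collator on a single prompt/completion example and assert that the extra key survives collation. The sketch below is assumption-heavy: the import path, constructor arguments, and example schema are inferred from this diff and from the dataset used in the test, not verified against the library.

# Sketch only: import path, constructor signature, and example schema are
# assumptions inferred from this diff, not verified TRL API.
from PIL import Image
from transformers import AutoProcessor
from trl.trainer.sft_trainer import DataCollatorForVisionLanguageModeling  # path assumed

processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Gemma3ForConditionalGeneration")
collator = DataCollatorForVisionLanguageModeling(processor, max_length=None)  # signature assumed

example = {
    "prompt": [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is this?"}]}],
    "completion": [{"role": "assistant", "content": [{"type": "text", "text": "A blank square."}]}],
    "images": [Image.new("RGB", (64, 64))],
}
batch = collator([example])
assert "token_type_ids" in batch  # before this commit, the key was dropped for Gemma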


@@ -424,15 +424,26 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
         input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
         attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
         completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1)
+        if "token_type_ids" in processed_prompts:  # special case for Gemma
+            prompt_token_type_ids = processed_prompts["token_type_ids"]
+            completion_token_type_ids = processed_completions["token_type_ids"]
+            token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1)
 
         # Flush left to reduce padding
-        attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
+        if "token_type_ids" in processed_prompts:
+            attention_mask, input_ids, completion_mask, token_type_ids = flush_left(
+                attention_mask, input_ids, completion_mask, token_type_ids
+            )
+        else:
+            attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
 
         # Truncate if necessary
         if self.max_length is not None:
             input_ids = input_ids[:, : self.max_length]
             attention_mask = attention_mask[:, : self.max_length]
             completion_mask = completion_mask[:, : self.max_length]
+            if "token_type_ids" in processed_prompts:
+                token_type_ids = token_type_ids[:, : self.max_length]
 
         # Create labels and mask padding tokens
         labels = input_ids.clone()
@@ -445,6 +456,8 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
         output["input_ids"] = input_ids
         output["attention_mask"] = attention_mask
         output["labels"] = labels
+        if "token_type_ids" in processed_prompts:
+            output["token_type_ids"] = token_type_ids
 
         return output
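flush_left is TRL's helper for removing shared left padding, and the point of the diff is that every per-token tensor, token_type_ids included, must be shifted by the same per-row offset, or image-token positions drift out of alignment with input_ids. A self-contained toy version (illustration only, not TRL's implementation) makes the semantics concrete:

import torch

def toy_flush_left(mask: torch.Tensor, *tensors: torch.Tensor):
    # Toy illustration of flush_left semantics (not TRL's implementation):
    # shift every row left by its number of leading padding positions, then
    # trim trailing columns that no row reaches anymore.
    batch_size, seq_len = mask.shape
    first = (mask.cumsum(dim=1) == 0).sum(dim=1)  # leading zeros per row
    out = [torch.zeros_like(t) for t in (mask, *tensors)]
    for i in range(batch_size):
        n = seq_len - int(first[i])
        for dst, src in zip(out, (mask, *tensors)):
            dst[i, :n] = src[i, int(first[i]):]
    keep = seq_len - int(first.min())  # longest row after shifting
    return tuple(t[:, :keep] for t in out)

# All per-token tensors travel together, exactly as in the diff above:
mask = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])
ids  = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])
tt   = torch.tensor([[0, 0, 1, 0], [0, 0, 1, 1]])  # stand-in token_type_ids
mask, ids, tt = toy_flush_left(mask, ids, tt)
# mask -> [[1, 1, 0], [1, 1, 1]], ids -> [[5, 6, 0], [7, 8, 9]], tt shifted identically

Shifting token_type_ids with a separate call, or not at all, would leave its 1s pointing at positions that no longer hold image tokens, which is precisely the misalignment the conditional flush_left branch prevents.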