mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
🏷️ Account for token_type_ids
in DataCollatorForVisionLanguageModeling
(#4190)
This commit is contained in:
committed by
GitHub
parent
824ff8c73e
commit
d1d0407d3c
@ -1441,6 +1441,38 @@ class TestSFTTrainer(TrlTestCase):
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
|
||||
|
||||
# Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly in the collator.
|
||||
@require_vision
|
||||
def test_train_vlm_prompt_completion_gemma(self):
|
||||
# Get the dataset
|
||||
dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")
|
||||
|
||||
# Initialize the trainer
|
||||
training_args = SFTConfig(
|
||||
output_dir=self.tmp_dir,
|
||||
max_length=None, # For VLMs, truncating can remove image tokens, leading to errors
|
||||
report_to="none",
|
||||
)
|
||||
trainer = SFTTrainer(
|
||||
model="trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
|
||||
args=training_args,
|
||||
train_dataset=dataset,
|
||||
)
|
||||
|
||||
# Save the initial parameters to compare them later
|
||||
previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
|
||||
|
||||
# Train the model
|
||||
trainer.train()
|
||||
|
||||
# Check that the training loss is not None
|
||||
self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
|
||||
|
||||
# Check the params have changed
|
||||
for n, param in previous_trainable_params.items():
|
||||
new_param = trainer.model.get_parameter(n)
|
||||
self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
|
||||
|
||||
# Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
|
||||
# To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
|
||||
@pytest.mark.slow
|
||||
|
@ -424,8 +424,17 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
|
||||
input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
|
||||
attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
|
||||
completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1)
|
||||
if "token_type_ids" in processed_prompts: # special case for Gemma
|
||||
prompt_token_type_ids = processed_prompts["token_type_ids"]
|
||||
completion_token_type_ids = processed_completions["token_type_ids"]
|
||||
token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1)
|
||||
|
||||
# Flush left to reduce padding
|
||||
if "token_type_ids" in processed_prompts:
|
||||
attention_mask, input_ids, completion_mask, token_type_ids = flush_left(
|
||||
attention_mask, input_ids, completion_mask, token_type_ids
|
||||
)
|
||||
else:
|
||||
attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
|
||||
|
||||
# Truncate if necessary
|
||||
@ -433,6 +442,8 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
|
||||
input_ids = input_ids[:, : self.max_length]
|
||||
attention_mask = attention_mask[:, : self.max_length]
|
||||
completion_mask = completion_mask[:, : self.max_length]
|
||||
if "token_type_ids" in processed_prompts:
|
||||
token_type_ids = token_type_ids[:, : self.max_length]
|
||||
|
||||
# Create labels and mask padding tokens
|
||||
labels = input_ids.clone()
|
||||
@ -445,6 +456,8 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
|
||||
output["input_ids"] = input_ids
|
||||
output["attention_mask"] = attention_mask
|
||||
output["labels"] = labels
|
||||
if "token_type_ids" in processed_prompts:
|
||||
output["token_type_ids"] = token_type_ids
|
||||
return output
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user