Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)
🏷️ Account for token_type_ids in DataCollatorForVisionLanguageModeling (#4190)
committed by GitHub
parent 824ff8c73e
commit d1d0407d3c
@@ -1441,6 +1441,38 @@ class TestSFTTrainer(TrlTestCase):
             new_param = trainer.model.get_parameter(n)
             assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated"
 
+    # Special case for Gemma, as it uses token_type_ids, and we need to ensure they are properly handled in the collator.
+    @require_vision
+    def test_train_vlm_prompt_completion_gemma(self):
+        # Get the dataset
+        dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_completion", split="train")
+
+        # Initialize the trainer
+        training_args = SFTConfig(
+            output_dir=self.tmp_dir,
+            max_length=None,  # For VLMs, truncating can remove image tokens, leading to errors
+            report_to="none",
+        )
+        trainer = SFTTrainer(
+            model="trl-internal-testing/tiny-Gemma3ForConditionalGeneration",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        # Save the initial parameters to compare them later
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        # Train the model
+        trainer.train()
+
+        # Check that the training loss is not None
+        self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+        # Check the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated")
+
     # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing.
     # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs.
     @pytest.mark.slow
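For context on what the test exercises: a Gemma-style processor returns a token_type_ids tensor alongside input_ids, flagging which positions are image tokens rather than text. The toy tensors below are invented for illustration (they are not real processor output), but they show the alignment contract the collator has to preserve:

import torch

# Toy stand-ins for a processed prompt containing two image tokens and a
# text-only completion; 1 marks an image token (assumed convention).
processed_prompts = {
    "input_ids": torch.tensor([[5, 6, 7, 8]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1]]),
    "token_type_ids": torch.tensor([[0, 1, 1, 0]]),
}
processed_completions = {
    "input_ids": torch.tensor([[9, 10]]),
    "attention_mask": torch.tensor([[1, 1]]),
    "token_type_ids": torch.tensor([[0, 0]]),
}

# token_type_ids must be concatenated exactly like input_ids, or the
# image/text flags drift out of alignment with the tokens they describe.
token_type_ids = torch.cat(
    (processed_prompts["token_type_ids"], processed_completions["token_type_ids"]), dim=1
)
print(token_type_ids)  # tensor([[0, 1, 1, 0, 0, 0]])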
@@ -424,15 +424,26 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
         input_ids = torch.cat((prompt_ids, completion_ids), dim=1)
         attention_mask = torch.cat((prompt_mask, completion_mask), dim=1)
         completion_mask = torch.cat((torch.zeros_like(prompt_mask), completion_mask), dim=1)
+        if "token_type_ids" in processed_prompts:  # special case for Gemma
+            prompt_token_type_ids = processed_prompts["token_type_ids"]
+            completion_token_type_ids = processed_completions["token_type_ids"]
+            token_type_ids = torch.cat((prompt_token_type_ids, completion_token_type_ids), dim=1)
 
         # Flush left to reduce padding
-        attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
+        if "token_type_ids" in processed_prompts:
+            attention_mask, input_ids, completion_mask, token_type_ids = flush_left(
+                attention_mask, input_ids, completion_mask, token_type_ids
+            )
+        else:
+            attention_mask, input_ids, completion_mask = flush_left(attention_mask, input_ids, completion_mask)
 
         # Truncate if necessary
         if self.max_length is not None:
             input_ids = input_ids[:, : self.max_length]
             attention_mask = attention_mask[:, : self.max_length]
             completion_mask = completion_mask[:, : self.max_length]
+            if "token_type_ids" in processed_prompts:
+                token_type_ids = token_type_ids[:, : self.max_length]
 
         # Create labels and mask padding tokens
         labels = input_ids.clone()
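To see why token_type_ids is threaded through flush_left together with the other tensors, here is a minimal sketch of left-flushing under the convention TRL's flush_left(mask, *tensors) appears to follow (a simplified reimplementation for illustration, not the library's actual code): every tensor riding along with the mask is shifted identically, so it stays column-aligned with input_ids.

import torch

def flush_left_sketch(mask, *tensors):
    # Shift the non-padding positions of each row to the left, apply the
    # same shift to every companion tensor, then trim columns that are
    # now entirely padding.
    flushed = [torch.zeros_like(t) for t in (mask, *tensors)]
    for i in range(mask.shape[0]):
        keep = mask[i].bool()
        n = int(keep.sum())
        for src, dst in zip((mask, *tensors), flushed):
            dst[i, :n] = src[i, keep]
    width = int(flushed[0].sum(dim=1).max())  # longest remaining row
    return tuple(t[:, :width] for t in flushed)

mask = torch.tensor([[0, 0, 1, 1], [0, 1, 1, 1]])            # left-padded batch
input_ids = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])
token_type_ids = torch.tensor([[0, 0, 1, 0], [0, 0, 1, 1]])  # image-token flags

mask, input_ids, token_type_ids = flush_left_sketch(mask, input_ids, token_type_ids)
print(input_ids)       # tensor([[5, 6, 0], [7, 8, 9]])
print(token_type_ids)  # tensor([[1, 0, 0], [0, 1, 1]]), still aligned

Concatenating token_type_ids but skipping it here would leave the flags pointing at pre-flush positions, which is why the call gains a fourth argument rather than the tensor being appended unflushed.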
@@ -445,6 +456,8 @@ class DataCollatorForVisionLanguageModeling(DataCollatorMixin):
         output["input_ids"] = input_ids
         output["attention_mask"] = attention_mask
         output["labels"] = labels
+        if "token_type_ids" in processed_prompts:
+            output["token_type_ids"] = token_type_ids
         return output
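Downstream, a hedged sketch of why emitting the extra key is enough: the Trainer forwards the collated batch to the model as keyword arguments, so a Gemma-style forward signature picks up token_type_ids with no further plumbing, while models that never emit the key are unaffected. The toy forward below is invented for illustration:

import torch

def toy_gemma_forward(input_ids, attention_mask, labels=None, token_type_ids=None, **kwargs):
    # token_type_ids is optional, mirroring models that may or may not use it.
    return {"received_token_type_ids": token_type_ids is not None}

batch = {
    "input_ids": torch.tensor([[5, 6, 7]]),
    "attention_mask": torch.tensor([[1, 1, 1]]),
    "labels": torch.tensor([[5, 6, 7]]),
    "token_type_ids": torch.tensor([[0, 1, 0]]),
}
print(toy_gemma_forward(**batch))  # {'received_token_type_ids': True}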