Compare commits

...

2 Commits

3 changed files with 62 additions and 47 deletions

View File

@ -162,6 +162,18 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))
def test_pad_to_multiple_of_and_padding_free(self):
"""Test padding to multiple of specified value."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, pad_to_multiple_of=4)
examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]
result = collator(examples)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))
def test_custom_position_ids(self):
"""Test handling of custom position IDs in examples."""
self.collator = DataCollatorForLanguageModeling(pad_token_id=0)

View File

@ -663,8 +663,9 @@ def pack_dataset(
>>> dataset = Dataset.from_dict(examples)
>>> packed_dataset = pack_dataset(dataset, seq_length=4, strategy="bfd")
>>> packed_dataset[:]
{'input_ids': [[1, 2, 3, 9], [6, 7, 8, 4, 5]],
'attention_mask': [[1, 1, 0, 1], [1, 0, 0, 1, 0]]}
{'input_ids': [[1, 2, 3, 9], [6, 7, 8], [4, 5]],
'attention_mask': [[1, 1, 0, 1], [1, 0, 0], [1, 0]],
'seq_lengths': [[3, 1], [3], [2]]}
```
"""
if map_kwargs is None:

View File

@ -124,7 +124,11 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
that are no in the completion.
padding_free (`bool`, *optional*, defaults to `False`):
If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
generated accordingly. The attention mask will be set to 1 for all tokens.
generated accordingly.
return_position_ids (`bool`, *optional*, defaults to `True`):
Whether to return position IDs. If `True`, position IDs are generated and returned. If `False`, attention
masks are generated and returned instead. Note that when using FlashAttention, this should be set to
`True`.
pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
If set, the sequences will be padded to a multiple of this value.
return_tensors (`str`, *optional*, defaults to `"pt"`):
@ -139,8 +143,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
>>> collator(examples)
{'input_ids': tensor([[ 1, 2, 3],
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'position_ids': tensor([[0, 1, 2],
[0, 1, 0]]),
'labels': tensor([[ 1, 2, 3],
@ -154,8 +156,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
>>> collator(examples)
{'input_ids': tensor([[ 1, 2, 3],
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'position_ids': tensor([[0, 1, 2],
[0, 1, 0]]),
'labels': tensor([[-100, 2, 3],
@ -182,14 +182,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
# Convert to tensor
input_ids = [torch.tensor(example["input_ids"]) for example in examples]
# Check if we have meaningful seq_lengths from packing (restarting sequences)
has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free
# For packing with position_ids, we should NOT create attention_mask as it causes
# FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask
if not has_packed_position_ids:
attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
# In practice, self.return_position_ids is True when using FlashAttention. When using FlashAttention, we should
# NOT create attention_mask as it causes FlashAttention to ignore position_ids and compute wrong cu_seq_lens
# from the all-1s mask.
if self.return_position_ids:
if "seq_lengths" in examples[0]:
position_ids = self.get_position_ids_from_packed_seq_lengths(
@ -197,6 +192,13 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
)
else:
position_ids = [torch.arange(len(ids)) for ids in input_ids]
else:
if "seq_lengths" in examples[0]:
logger.warning(
"The input examples contain `seq_lengths` but `return_position_ids` is set to `False`. "
"`seq_lengths` will be ignored."
)
attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
if "labels" in examples[0]:
labels = [torch.tensor(example["labels"]) for example in examples]
else:
@ -206,48 +208,48 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
if "assistant_masks" in examples[0]:
assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples]
# Pad
# If padding_free, flatten everything into a single sequence
output = {}
if self.padding_free:
output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0)
if not has_packed_position_ids:
output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
input_ids = [torch.cat(input_ids, dim=0)]
if self.return_position_ids:
output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0)
output["labels"] = torch.cat(labels, dim=0).unsqueeze(0)
position_ids = [torch.cat(position_ids, dim=0)]
else:
attention_mask = [torch.cat(attention_mask, dim=0)]
labels = [torch.cat(labels, dim=0)]
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = torch.cat(completion_mask, dim=0).unsqueeze(0)
output["labels"][completion_mask == 0] = -100
completion_mask = [torch.cat(completion_mask, dim=0)]
if "assistant_masks" in examples[0]:
assistant_masks = torch.cat(assistant_masks, dim=0).unsqueeze(0)
output["labels"][assistant_masks == 0] = -100
else:
output["input_ids"] = pad(
input_ids,
padding_value=self.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
assistant_masks = [torch.cat(assistant_masks, dim=0)]
# Pad
output["input_ids"] = pad(
input_ids,
padding_value=self.pad_token_id,
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
if self.return_position_ids:
output["position_ids"] = pad(
position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
else:
output["attention_mask"] = pad(
attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.return_position_ids:
output["position_ids"] = pad(
position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"] = pad(
labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
output["labels"] = pad(
labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = pad(
completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = pad(
completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion
if "assistant_masks" in examples[0]:
assistant_masks = pad(
assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"][assistant_masks == 0] = -100
output["labels"][completion_mask == 0] = -100 # mask everything that is not in the completion
if "assistant_masks" in examples[0]:
assistant_masks = pad(
assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"][assistant_masks == 0] = -100
return output
@staticmethod