@@ -124,7 +124,11 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
             that are not in the completion.
         padding_free (`bool`, *optional*, defaults to `False`):
             If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
-            generated accordingly. The attention mask will be set to 1 for all tokens.
+            generated accordingly.
+        return_position_ids (`bool`, *optional*, defaults to `True`):
+            Whether to return position IDs. If `True`, position IDs are generated and returned. If `False`, attention
+            masks are generated and returned instead. Note that when using FlashAttention, this should be set to
+            `True`.
         pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
             If set, the sequences will be padded to a multiple of this value.
         return_tensors (`str`, *optional*, defaults to `"pt"`):
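
For reference, a minimal usage sketch of the two flags documented above (not part of the diff; the expected output keys assume the post-change behavior, in which position IDs replace the attention mask by default):

    from trl import DataCollatorForLanguageModeling

    examples = [{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}]

    # Default: position IDs are returned (the path FlashAttention needs).
    collator = DataCollatorForLanguageModeling(pad_token_id=0)
    print(sorted(collator(examples).keys()))  # ['input_ids', 'labels', 'position_ids']

    # With return_position_ids=False, an attention mask is returned instead.
    collator = DataCollatorForLanguageModeling(pad_token_id=0, return_position_ids=False)
    print(sorted(collator(examples).keys()))  # ['attention_mask', 'input_ids', 'labels']
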
@@ -139,8 +143,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     >>> collator(examples)
     {'input_ids': tensor([[ 1, 2, 3],
                           [ 4, 5, 0]]),
-     'attention_mask': tensor([[ 1, 1, 1],
-                               [ 1, 1, 0]]),
      'position_ids': tensor([[0, 1, 2],
                              [0, 1, 0]]),
      'labels': tensor([[ 1, 2, 3],
@@ -154,8 +156,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
     >>> collator(examples)
     {'input_ids': tensor([[ 1, 2, 3],
                           [ 4, 5, 0]]),
-     'attention_mask': tensor([[ 1, 1, 1],
-                               [ 1, 1, 0]]),
      'position_ids': tensor([[0, 1, 2],
                              [0, 1, 0]]),
      'labels': tensor([[-100, 2, 3],
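
The `-100` entries in the example above come from completion masking. A small plain-torch illustration of the same rule applied later in `torch_call` (`labels[completion_mask == 0] = -100`); the completion masks here are hypothetical, chosen only to reproduce the doctest's first row:

    import torch

    labels = torch.tensor([[1, 2, 3], [4, 5, -100]])         # -100 already marks padding
    completion_mask = torch.tensor([[0, 1, 1], [0, 1, 0]])   # hypothetical: 0 = prompt/padding, 1 = completion
    labels = labels.clone()
    labels[completion_mask == 0] = -100                       # same masking rule as in the collator
    print(labels)
    # tensor([[-100,    2,    3],
    #         [-100,    5, -100]])
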
@@ -182,14 +182,9 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         # Convert to tensor
         input_ids = [torch.tensor(example["input_ids"]) for example in examples]
 
-        # Check if we have meaningful seq_lengths from packing (restarting sequences)
-        has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free
-
-        # For packing with position_ids, we should NOT create attention_mask as it causes
-        # FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask
-        if not has_packed_position_ids:
-            attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
-
+        # In practice, self.return_position_ids is True when using FlashAttention. When using FlashAttention, we should
+        # NOT create attention_mask as it causes FlashAttention to ignore position_ids and compute wrong cu_seq_lens
+        # from the all-1s mask.
         if self.return_position_ids:
             if "seq_lengths" in examples[0]:
                 position_ids = self.get_position_ids_from_packed_seq_lengths(
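
A short sketch of the default position-ID path referenced here (one `torch.arange` per example, later right-padded); this mirrors the branch shown in the next hunk and is not part of the diff:

    import torch

    input_ids = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    position_ids = [torch.arange(len(ids)) for ids in input_ids]
    print(position_ids)  # [tensor([0, 1, 2]), tensor([0, 1])]
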
@@ -197,6 +192,13 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
                 )
             else:
                 position_ids = [torch.arange(len(ids)) for ids in input_ids]
+        else:
+            if "seq_lengths" in examples[0]:
+                logger.warning(
+                    "The input examples contain `seq_lengths` but `return_position_ids` is set to `False`. "
+                    "`seq_lengths` will be ignored."
+                )
+            attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
         if "labels" in examples[0]:
             labels = [torch.tensor(example["labels"]) for example in examples]
         else:
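
For intuition, a hedged sketch (not the library implementation) of what restarting position IDs per packed sub-sequence looks like, assuming `seq_lengths = [3, 2]` for a single packed example:

    import torch

    seq_lengths = [3, 2]
    position_ids = torch.cat([torch.arange(n) for n in seq_lengths])
    print(position_ids)  # tensor([0, 1, 2, 0, 1])
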
@@ -206,48 +208,48 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
         if "assistant_masks" in examples[0]:
             assistant_masks = [torch.tensor(example["assistant_masks"]) for example in examples]
 
-        # Pad
+        # If padding_free, flatten everything into a single sequence
         output = {}
         if self.padding_free:
-            output["input_ids"] = torch.cat(input_ids, dim=0).unsqueeze(0)
-            if not has_packed_position_ids:
-                output["attention_mask"] = torch.cat(attention_mask, dim=0).unsqueeze(0)
-            if self.return_position_ids:
-                output["position_ids"] = torch.cat(position_ids, dim=0).unsqueeze(0)
-            output["labels"] = torch.cat(labels, dim=0).unsqueeze(0)
-            if self.completion_only_loss and "completion_mask" in examples[0]:
-                completion_mask = torch.cat(completion_mask, dim=0).unsqueeze(0)
-                output["labels"][completion_mask == 0] = -100
-            if "assistant_masks" in examples[0]:
-                assistant_masks = torch.cat(assistant_masks, dim=0).unsqueeze(0)
-                output["labels"][assistant_masks == 0] = -100
-        else:
-            output["input_ids"] = pad(
-                input_ids,
-                padding_value=self.pad_token_id,
-                padding_side="right",
-                pad_to_multiple_of=self.pad_to_multiple_of,
-            )
-            output["attention_mask"] = pad(
-                attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-            )
-            if self.return_position_ids:
-                output["position_ids"] = pad(
-                    position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-                )
-            output["labels"] = pad(
-                labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-            )
-            if self.completion_only_loss and "completion_mask" in examples[0]:
-                completion_mask = pad(
-                    completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-                )
-                output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion
-            if "assistant_masks" in examples[0]:
-                assistant_masks = pad(
-                    assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
-                )
-                output["labels"][assistant_masks == 0] = -100
+            input_ids = [torch.cat(input_ids, dim=0)]
+            if self.return_position_ids:
+                position_ids = [torch.cat(position_ids, dim=0)]
+            else:
+                attention_mask = [torch.cat(attention_mask, dim=0)]
+            labels = [torch.cat(labels, dim=0)]
+            if self.completion_only_loss and "completion_mask" in examples[0]:
+                completion_mask = [torch.cat(completion_mask, dim=0)]
+            if "assistant_masks" in examples[0]:
+                assistant_masks = [torch.cat(assistant_masks, dim=0)]
+
+        # Pad
+        output["input_ids"] = pad(
+            input_ids,
+            padding_value=self.pad_token_id,
+            padding_side="right",
+            pad_to_multiple_of=self.pad_to_multiple_of,
+        )
+        if self.return_position_ids:
+            output["position_ids"] = pad(
+                position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+            )
+        else:
+            output["attention_mask"] = pad(
+                attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+            )
+        output["labels"] = pad(
+            labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+        )
+        if self.completion_only_loss and "completion_mask" in examples[0]:
+            completion_mask = pad(
+                completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+            )
+            output["labels"][completion_mask == 0] = -100  # mask everything that is not in the completion
+        if "assistant_masks" in examples[0]:
+            assistant_masks = pad(
+                assistant_masks, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
+            )
+            output["labels"][assistant_masks == 0] = -100
         return output
 
     @staticmethod
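
To illustrate the flatten-then-pad flow introduced above, here is a simplified, hypothetical stand-in (`pad_right`) for the `pad` helper used in the diff, applied to a padding-free batch padded to a multiple of 8:

    import torch

    def pad_right(tensors, padding_value, pad_to_multiple_of=None):
        # Right-pad a list of 1-D tensors into one 2-D tensor, optionally rounding
        # the target length up to a multiple.
        length = max(t.size(0) for t in tensors)
        if pad_to_multiple_of is not None:
            length = ((length + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
        out = torch.full((len(tensors), length), padding_value, dtype=tensors[0].dtype)
        for i, t in enumerate(tensors):
            out[i, : t.size(0)] = t
        return out

    input_ids = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]
    flat = [torch.cat(input_ids, dim=0)]  # padding_free: one flattened row
    print(pad_right(flat, padding_value=0, pad_to_multiple_of=8))
    # tensor([[1, 2, 3, 4, 5, 0, 0, 0]])
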