Fix Flash Attention x Padding-Free loss (#4170)

Quentin Gallouédec
2025-09-30 12:01:29 -06:00
committed by GitHub
parent 70e2017dbc
commit ebb8899f5d
2 changed files with 71 additions and 107 deletions
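The core of the fix: with padding-free training or BFD packing, each batch is flattened into a single row, and FlashAttention's variable-length path has to recover the sample boundaries. Those boundaries can only come from position IDs that restart at 0; if an all-ones attention mask is returned alongside them, the cumulative sequence lengths get derived from the mask instead, the whole row is treated as one sequence, and the loss leaks across packed samples. A rough illustration of the two derivations (illustrative helper names, not the actual FlashAttention or transformers code):

```python
import torch


def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # Every position id equal to 0 marks the start of a packed sub-sequence.
    flat = position_ids.flatten()
    starts = torch.nonzero(flat == 0).flatten()
    return torch.cat([starts, torch.tensor([flat.numel()])]).to(torch.int32)


def cu_seqlens_from_all_ones_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # An all-ones mask carries no boundary information: the flattened row
    # looks like one single long sequence.
    return torch.tensor([0, int(attention_mask.sum())], dtype=torch.int32)


# Three samples of lengths 3, 2 and 3, flattened into one row.
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]])
attention_mask = torch.ones_like(position_ids)

print(cu_seqlens_from_position_ids(position_ids))     # tensor([0, 3, 5, 8], dtype=torch.int32)
print(cu_seqlens_from_all_ones_mask(attention_mask))  # tensor([0, 8], dtype=torch.int32)
```

Accordingly, the collator stops returning an attention mask in padding-free mode, returns restarting position IDs instead, and masks the label of the first token of every sub-sequence.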

@@ -64,9 +64,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
result = self.collator(examples)
self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
def test_completion_mask(self):
@@ -79,9 +79,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
result = self.collator(examples)
self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))
def test_completion_only_loss_disabled(self):
@@ -95,9 +95,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
result = collator(examples)
# Labels should not be masked when completion_only_loss=False
self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
def test_padding_free_mode(self):
@@ -107,72 +107,42 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
result = collator(examples)
self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5]]))
def test_padding_free_with_completion_mask(self):
"""Test padding-free mode with completion masks."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
examples = [
{"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
{"input_ids": [1, 2, 3], "completion_mask": [0, 0, 1]},
{"input_ids": [4, 5], "completion_mask": [1, 1]},
]
result = collator(examples)
self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3, -100, 5]]))
def test_packing_drops_attention_mask_for_flash_attention(self):
def test_packing(self):
"""Test that when using packing with position_ids, attention_mask is dropped with fa2."""
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
# Simulate packed sequences with position_ids that restart (typical of BFD packing)
examples = [
{
"input_ids": [1, 2, 3, 4, 5, 6, 7, 8], # Packed: [1,2,3] + [4,5] + [6,7,8]
"seq_lengths": [3, 2, 3],
}
{"input_ids": [1, 2, 3, 4, 5, 6], "seq_lengths": [3, 3]},
{"input_ids": [7, 8, 9, 10, 11], "seq_lengths": [4, 1]},
]
result = collator(examples)
# Verify that attention_mask is NOT present - this allows FlashAttention to use position_ids
self.assertNotIn("attention_mask", result, "attention_mask should be dropped for packing with position_ids")
# Verify essential keys are present
self.assertIn("input_ids", result)
self.assertIn("position_ids", result)
self.assertIn("labels", result)
# Verify the data is correctly processed
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
def test_padding_free_without_position_ids_keeps_attention_mask(self):
"""
Test that padding_free mode without explicit position_ids still creates attention_mask.
"""
collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
# Examples without position_ids (not packed)
examples = [{"input_ids": [1, 2, 3, 4, 5]}]
result = collator(examples)
# Should still have attention_mask since no packed position_ids
self.assertIn("attention_mask", result, "attention_mask should be present when no packed position_ids")
self.assertIn("position_ids", result)
self.assertIn("input_ids", result)
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3, 4]]))
self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]]))
def test_pad_to_multiple_of(self):
"""Test padding to multiple of specified value."""
@@ -181,9 +151,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
result = collator(examples)
self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))
def test_pad_to_multiple_of_and_padding_free(self):
@@ -193,21 +163,21 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
result = collator(examples)
self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, -100, -100, -100]]))
def test_custom_position_ids(self):
"""Test handling of custom position IDs in examples."""
def test_custom_position_ids_but_no_padding_free(self):
"""Test that custom position_ids are ignored if padding_free is False."""
self.collator = DataCollatorForLanguageModeling(pad_token_id=0)
examples = [{"input_ids": [1, 2, 3], "seq_lengths": [1, 2]}, {"input_ids": [4, 5], "seq_lengths": [2]}]
result = self.collator(examples)
self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 0, 1], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
def test_single_example(self):
@@ -217,9 +187,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
result = self.collator(examples)
self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]]))
def test_different_pad_token_id(self):
@@ -229,9 +199,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
result = collator(examples)
self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))
def test_assistant_masks(self):
@@ -246,7 +216,6 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))
def test_single_example_single_doc(self):
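The label pattern asserted in the padding-free and packing tests above follows a single rule: the label of any token whose position id is 0 is set to -100, because that token has no preceding token from its own sample to be predicted from. A minimal sketch of the rule (the helper name below is hypothetical; in the diff, `get_position_ids_from_packed_seq_lengths` produces the restarting position ids and a separate `output["labels"][output["position_ids"] == 0] = -100` assignment applies the masking):

```python
import torch


def position_ids_from_seq_lengths(seq_lengths: list[int]) -> torch.Tensor:
    # Hypothetical stand-in: turn packed sub-sequence lengths into position ids
    # that restart at 0, e.g. [3, 2, 3] -> [0, 1, 2, 0, 1, 0, 1, 2].
    return torch.cat([torch.arange(length) for length in seq_lengths])


position_ids = position_ids_from_seq_lengths([3, 2, 3])
labels = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])

# Mask the first token of every sub-sequence so the loss never crosses a
# packing boundary.
labels[position_ids == 0] = -100

print(position_ids)  # tensor([0, 1, 2, 0, 1, 0, 1, 2])
print(labels)        # tensor([-100, 2, 3, -100, 5, -100, 7, 8])
```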

@@ -68,6 +68,14 @@ logger = logging.get_logger(__name__)
TListOrMapping = TypeVar("TListOrMapping", list, Mapping)
FLASH_ATTENTION_VARIANTS = {
"flash_attention_2",
"flash_attention_3",
"kernels-community/flash-attn",
"kernels-community/vllm-flash-attn3",
"kernels-community/flash-attn3",
}
def remove_none_values(example: TListOrMapping) -> TListOrMapping:
"""
@@ -115,11 +123,13 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
completion. If `"assistant_masks"` are present, they are used to set the labels to `-100` for tokens that are not
in the assistant part of the sequence. The collator returns a dictionary containing the following keys:
- `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch.
- `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
- `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch.
- `"labels"`: Tensor of labels, padded to the maximum length of the batch. If `completion_only_loss` is set to
`True`, tokens that are not in the completion are set to -100. If `assistant_masks` are present, tokens that are
not in the assistant part of the sequence are set to -100.
not in the assistant part of the sequence are set to -100. If `padding_free` is set to `False`, the following key
is also returned:
- `"attention_mask"`: Tensor of attention masks, padded to the maximum length of the batch.
If `padding_free` is set to `True`, the following key is also returned:
- `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch.
Args:
pad_token_id (`int`):
@@ -129,7 +139,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
that are not in the completion.
padding_free (`bool`, *optional*, defaults to `False`):
If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
generated accordingly.
generated accordingly and returned instead of the attention mask.
pad_to_multiple_of (`int`, *optional*):
If set, the sequences will be padded to a multiple of this value.
return_tensors (`str`, *optional*, defaults to `"pt"`):
@@ -146,8 +156,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'position_ids': tensor([[0, 1, 2],
[0, 1, 0]]),
'labels': tensor([[ 1, 2, 3],
[ 4, 5, -100]])}
@@ -161,8 +169,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
[ 4, 5, 0]]),
'attention_mask': tensor([[ 1, 1, 1],
[ 1, 1, 0]]),
'position_ids': tensor([[0, 1, 2],
[0, 1, 0]]),
'labels': tensor([[-100, 2, 3],
[-100, 5, -100]])}
@@ -170,7 +176,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
>>> collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
>>> collator(examples)
{'input_ids': tensor([[ 1, 2, 3, 4, 5]]),
'attention_mask': tensor([[1, 1, 1, 1, 1]]),
'position_ids': tensor([[0, 1, 2, 0, 1]]),
'labels': tensor([[1, 2, 3, 4, 5]])}
```
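For the packed case with pre-computed `seq_lengths`, a quick usage sketch mirroring the updated `test_packing` above (the import path is an assumption, since file paths are not shown in this view):

```python
from trl.trainer.sft_trainer import DataCollatorForLanguageModeling  # import path assumed

collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
examples = [
    {"input_ids": [1, 2, 3, 4, 5, 6], "seq_lengths": [3, 3]},
    {"input_ids": [7, 8, 9, 10, 11], "seq_lengths": [4, 1]},
]
batch = collator(examples)

# No attention_mask is returned, so FlashAttention recovers the sample
# boundaries from the restarting position_ids.
print(sorted(batch))          # ['input_ids', 'labels', 'position_ids']
print(batch["position_ids"])  # tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]])
print(batch["labels"])        # tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]])
```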
@@ -179,33 +184,28 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
pad_token_id: int
completion_only_loss: bool = True
padding_free: bool = False
return_position_ids: bool = True
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"
def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
# Convert to tensor
input_ids = [torch.tensor(example["input_ids"]) for example in examples]
if "labels" in examples[0]:
labels = [torch.tensor(example["labels"]) for example in examples]
else:
labels = [torch.tensor(example["input_ids"]) for example in examples]
# Check if we have meaningful seq_lengths from packing (restarting sequences)
has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free
# For packing with position_ids, we should NOT create attention_mask as it causes
# FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask
if not has_packed_position_ids:
attention_mask = [torch.ones_like(ids) for ids in input_ids]
if self.return_position_ids:
# For padding-free, we should NOT create attention_mask as it causes FlashAttention to ignore position_ids and
# compute wrong cu_seq_lens from the all-1s mask
if self.padding_free:
if "seq_lengths" in examples[0]:
position_ids = self.get_position_ids_from_packed_seq_lengths(
[example["seq_lengths"] for example in examples]
)
else:
position_ids = [torch.arange(len(ids)) for ids in input_ids]
if "labels" in examples[0]:
labels = [torch.tensor(example["labels"]) for example in examples]
else:
labels = [torch.tensor(example["input_ids"]) for example in examples]
attention_mask = [torch.ones_like(ids) for ids in input_ids]
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = [torch.tensor(example["completion_mask"]) for example in examples]
if "assistant_masks" in examples[0]:
@@ -215,9 +215,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
output = {}
if self.padding_free:
input_ids = [torch.cat(input_ids, dim=0)]
if self.return_position_ids:
position_ids = [torch.cat(position_ids, dim=0)]
labels = [torch.cat(labels, dim=0)]
position_ids = [torch.cat(position_ids, dim=0)]
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = [torch.cat(completion_mask, dim=0)]
if "assistant_masks" in examples[0]:
@@ -230,18 +229,18 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
padding_side="right",
pad_to_multiple_of=self.pad_to_multiple_of,
)
if not has_packed_position_ids:
attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
output["attention_mask"] = pad(
attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.return_position_ids:
output["position_ids"] = pad(
position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"] = pad(
labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.padding_free:
output["position_ids"] = pad(
position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
output["labels"][output["position_ids"] == 0] = -100
else:
output["attention_mask"] = pad(
attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
)
if self.completion_only_loss and "completion_mask" in examples[0]:
completion_mask = pad(
completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
@@ -718,11 +717,7 @@ class SFTTrainer(BaseTrainer):
# BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing
# FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask.
self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd")
use_flash_attention = model.config._attn_implementation in [
"flash_attention_2",
"flash_attention_3",
"kernels-community/vllm-flash-attn3",
]
use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS
if self.padding_free:
if data_collator is not None:
raise ValueError("Passing a custom data collator is not supported when using padding-free.")
@@ -733,13 +728,15 @@ class SFTTrainer(BaseTrainer):
)
if not use_flash_attention:
logger.warning(
"Padding-free training is enabled, but the attention implementation is not set to "
"'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
"'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
"other implementations may lead to unexpected behavior. To ensure compatibility, set "
"`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
"attention mechanism can handle flattened sequences."
"Padding-free training is enabled, but the attention implementation is not set to a supported "
"flash attention variant. Padding-free training flattens batches into a single sequence, and only "
"the following implementations are known to reliably support this: "
f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to "
"unexpected behavior. To ensure compatibility, set `attn_implementation` in the model "
"configuration to one of these supported options or verify that your attention mechanism can "
"handle flattened sequences."
)
if args.per_device_train_batch_size == 1 and not args.packing:
logger.warning(
"You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size "
@@ -777,8 +774,6 @@ class SFTTrainer(BaseTrainer):
pad_token_id=pad_token_id,
completion_only_loss=self.completion_only_loss,
padding_free=self.padding_free,
# Using position_ids without flash_attn hurts the training
return_position_ids=use_flash_attention,
pad_to_multiple_of=args.pad_to_multiple_of,
)
elif data_collator is None and self._is_vision_dataset:
@@ -792,12 +787,12 @@ class SFTTrainer(BaseTrainer):
if args.packing and args.packing_strategy == "bfd" and not use_flash_attention:
logger.warning(
"You are using packing, but the attention implementation is not set to 'flash_attention_2' or "
"'kernels-community/vllm-flash-attn3'. Packing flattens batches into a single sequence, and Flash "
"Attention is the only known attention mechanisms that reliably support this. Using other "
"implementations may lead to cross-contamination between batches. To avoid this, either disable "
"packing by setting `packing=False`, or set `attn_implementation='flash_attention_2'` or "
"`attn_implementation='kernels-community/vllm-flash-attn3'` in the model configuration."
"You are using packing, but the attention implementation is not set to a supported flash attention "
"variant. Packing gathers multiple samples into a single sequence, and only the following "
f"implementations are known to reliably support this: {', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. "
"Using other implementations may lead to cross-contamination between samples. To avoid this, either "
"disable packing by setting `packing=False`, or set `attn_implementation` in the model configuration "
"to one of these supported options."
)
if args.assistant_only_loss and not is_conversational(dataset_sample):
raise ValueError(