Mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)

⚡ Fix Flash Attention x Padding-Free loss (#4170)

Committed by GitHub
Parent: 70e2017dbc
Commit: ebb8899f5d
@@ -64,9 +64,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
        result = self.collator(examples)

        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_completion_mask(self):
@@ -79,9 +79,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
        result = self.collator(examples)

        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))

    def test_completion_only_loss_disabled(self):
@@ -95,9 +95,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
        result = collator(examples)

        # Labels should not be masked when completion_only_loss=False
        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_padding_free_mode(self):
@@ -107,72 +107,42 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
        result = collator(examples)

        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5]]))

    def test_padding_free_with_completion_mask(self):
        """Test padding-free mode with completion masks."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
        examples = [
            {"input_ids": [1, 2, 3], "completion_mask": [0, 1, 1]},
            {"input_ids": [1, 2, 3], "completion_mask": [0, 0, 1]},
            {"input_ids": [4, 5], "completion_mask": [1, 1]},
        ]

        result = collator(examples)

        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, 4, 5]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3, -100, 5]]))

    def test_packing_drops_attention_mask_for_flash_attention(self):
    def test_packing(self):
        """Test that when using packing with position_ids, attention_mask is dropped with fa2."""
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)

        # Simulate packed sequences with position_ids that restart (typical of BFD packing)
        examples = [
            {
                "input_ids": [1, 2, 3, 4, 5, 6, 7, 8],  # Packed: [1,2,3] + [4,5] + [6,7,8]
                "seq_lengths": [3, 2, 3],
            }
            {"input_ids": [1, 2, 3, 4, 5, 6], "seq_lengths": [3, 3]},
            {"input_ids": [7, 8, 9, 10, 11], "seq_lengths": [4, 1]},
        ]

        result = collator(examples)

        # Verify that attention_mask is NOT present - this allows FlashAttention to use position_ids
        self.assertNotIn("attention_mask", result, "attention_mask should be dropped for packing with position_ids")

        # Verify essential keys are present
        self.assertIn("input_ids", result)
        self.assertIn("position_ids", result)
        self.assertIn("labels", result)

        # Verify the data is correctly processed
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 1, 2]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]))

    def test_padding_free_without_position_ids_keeps_attention_mask(self):
        """
        Test that padding_free mode without explicit position_ids still creates attention_mask.
        """
        collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True, return_position_ids=True)

        # Examples without position_ids (not packed)
        examples = [{"input_ids": [1, 2, 3, 4, 5]}]

        result = collator(examples)

        # Should still have attention_mask since no packed position_ids
        self.assertIn("attention_mask", result, "attention_mask should be present when no packed position_ids")
        self.assertIn("position_ids", result)
        self.assertIn("input_ids", result)

        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3, 4]]))
        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]]))

    def test_pad_to_multiple_of(self):
        """Test padding to multiple of specified value."""
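To make the expected `position_ids` in the `test_packing` hunk above easier to follow, here is a minimal standalone sketch of how restarting position IDs can be derived from packed sequence lengths. It is illustrative only; the collator itself uses `get_position_ids_from_packed_seq_lengths`, which appears further down in this diff.

```python
import torch

def position_ids_from_seq_lengths(seq_lengths: list[int]) -> torch.Tensor:
    # Each packed sub-sequence restarts its position counter at 0.
    return torch.cat([torch.arange(n) for n in seq_lengths])

print(position_ids_from_seq_lengths([3, 3]))  # tensor([0, 1, 2, 0, 1, 2])
print(position_ids_from_seq_lengths([4, 1]))  # tensor([0, 1, 2, 3, 0])
```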
@@ -181,9 +151,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
        result = collator(examples)

        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0], [0, 1, 0, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]]))

    def test_pad_to_multiple_of_and_padding_free(self):
@@ -193,21 +163,21 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
        result = collator(examples)

        self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1, 1, 0, 0, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4, 5, -100, -100, -100]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, -100, -100, -100]]))

    def test_custom_position_ids(self):
        """Test handling of custom position IDs in examples."""
    def test_custom_position_ids_but_no_padding_free(self):
        """Test that custom position_ids are ignored if padding_free is False."""
        self.collator = DataCollatorForLanguageModeling(pad_token_id=0)
        examples = [{"input_ids": [1, 2, 3], "seq_lengths": [1, 2]}, {"input_ids": [4, 5], "seq_lengths": [2]}]

        result = self.collator(examples)

        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 0, 1], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_single_example(self):
@@ -217,9 +187,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
        result = self.collator(examples)

        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 3]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]]))

    def test_different_pad_token_id(self):
@@ -229,9 +199,9 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
        result = collator(examples)

        self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"})
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]]))

    def test_assistant_masks(self):
@@ -246,7 +216,6 @@ class TestDataCollatorForLanguageModeling(TrlTestCase):
        torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]]))
        torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]]))
        torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2], [0, 1, 0]]))
        torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]]))

    def test_single_example_single_doc(self):
@@ -68,6 +68,14 @@ logger = logging.get_logger(__name__)

TListOrMapping = TypeVar("TListOrMapping", list, Mapping)

FLASH_ATTENTION_VARIANTS = {
    "flash_attention_2",
    "flash_attention_3",
    "kernels-community/flash-attn",
    "kernels-community/vllm-flash-attn3",
    "kernels-community/flash-attn3",
}


def remove_none_values(example: TListOrMapping) -> TListOrMapping:
    """
@@ -115,11 +123,13 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
    completion. If `"assistant_masks"` are present, they are used to set the labels to `-100` for tokens that are not
    in the assistant part of the sequence. The collator returns a dictionary containing the following keys:
    - `"input_ids"`: Tensor of input IDs, padded to the maximum length of the batch.
    - `"attention_mask"`: Tensor of attention mask, padded to the maximum length of the batch.
    - `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch.
    - `"labels"`: Tensor of labels, padded to the maximum length of the batch. If `completion_only_loss` is set to
      `True`, tokens that are not in the completion are set to -100. If `assistant_masks` are present, tokens that are
      not in the assistant part of the sequence are set to -100.
      not in the assistant part of the sequence are set to -100. If `padding_free` is set to `False`, the following key
    is also returned:
    - `"attention_mask"`: Tensor of attention masks, padded to the maximum length of the batch.
    If `padding_free` is set to `True`, the following key is also returned:
    - `"position_ids"`: Tensor of position IDs, padded to the maximum length of the batch.

    Args:
        pad_token_id (`int`):
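As a quick illustration of the labelling rule the docstring hunk above describes, the sketch below shows how prompt tokens end up with the -100 ignore index when `completion_only_loss` is enabled. This is a minimal sketch of the rule, not the collator's actual code; the values match the `test_padding_free_with_completion_mask` example earlier in this diff.

```python
import torch

input_ids = torch.tensor([1, 2, 3])
completion_mask = torch.tensor([0, 1, 1])  # 0 = prompt token, 1 = completion token

# Prompt tokens are excluded from the loss via the -100 ignore index.
labels = torch.where(completion_mask.bool(), input_ids, torch.full_like(input_ids, -100))
print(labels)  # tensor([-100, 2, 3])
```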
@@ -129,7 +139,7 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
            that are not in the completion.
        padding_free (`bool`, *optional*, defaults to `False`):
            If set to `True`, the sequences will be flattened into a single sequence, and the position IDs will be
            generated accordingly.
            generated accordingly and returned instead of the attention mask.
        pad_to_multiple_of (`int`, *optional*):
            If set, the sequences will be padded to a multiple of this value.
        return_tensors (`str`, *optional*, defaults to `"pt"`):
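A minimal sketch of the flattening that `padding_free=True` performs, using the same toy batch as the docstring examples below. It is illustrative only, not the collator implementation.

```python
import torch

input_ids = [torch.tensor([1, 2, 3]), torch.tensor([4, 5])]

# All sequences are concatenated into a single row; position IDs restart per sequence,
# which is how downstream attention code can still tell the sequences apart.
flat_input_ids = torch.cat(input_ids).unsqueeze(0)
position_ids = torch.cat([torch.arange(len(ids)) for ids in input_ids]).unsqueeze(0)
print(flat_input_ids)  # tensor([[1, 2, 3, 4, 5]])
print(position_ids)    # tensor([[0, 1, 2, 0, 1]])
```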
@@ -146,8 +156,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
                  [ 4, 5, 0]]),
    'attention_mask': tensor([[ 1, 1, 1],
                  [ 1, 1, 0]]),
    'position_ids': tensor([[0, 1, 2],
                  [0, 1, 0]]),
    'labels': tensor([[ 1, 2, 3],
                  [ 4, 5, -100]])}

@@ -161,8 +169,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
                  [ 4, 5, 0]]),
    'attention_mask': tensor([[ 1, 1, 1],
                  [ 1, 1, 0]]),
    'position_ids': tensor([[0, 1, 2],
                  [0, 1, 0]]),
    'labels': tensor([[-100, 2, 3],
                  [-100, 5, -100]])}
@@ -170,7 +176,6 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
    >>> collator = DataCollatorForLanguageModeling(pad_token_id=0, padding_free=True)
    >>> collator(examples)
    {'input_ids': tensor([[ 1, 2, 3, 4, 5]]),
     'attention_mask': tensor([[1, 1, 1, 1, 1]]),
     'position_ids': tensor([[0, 1, 2, 0, 1]]),
     'labels': tensor([[1, 2, 3, 4, 5]])}
    ```
@@ -179,33 +184,28 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
    pad_token_id: int
    completion_only_loss: bool = True
    padding_free: bool = False
    return_position_ids: bool = True
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def torch_call(self, examples: list[dict[str, Any]]) -> dict[str, Any]:
        # Convert to tensor
        input_ids = [torch.tensor(example["input_ids"]) for example in examples]
        if "labels" in examples[0]:
            labels = [torch.tensor(example["labels"]) for example in examples]
        else:
            labels = [torch.tensor(example["input_ids"]) for example in examples]

        # Check if we have meaningful seq_lengths from packing (restarting sequences)
        has_packed_position_ids = self.return_position_ids and "seq_lengths" in examples[0] and self.padding_free

        # For packing with position_ids, we should NOT create attention_mask as it causes
        # FlashAttention to ignore position_ids and compute wrong cu_seq_lens from the all-1s mask
        if not has_packed_position_ids:
            attention_mask = [torch.ones_like(ids) for ids in input_ids]

        if self.return_position_ids:
        # For padding-free, we should NOT create attention_mask as it causes FlashAttention to ignore position_ids and
        # compute wrong cu_seq_lens from the all-1s mask
        if self.padding_free:
            if "seq_lengths" in examples[0]:
                position_ids = self.get_position_ids_from_packed_seq_lengths(
                    [example["seq_lengths"] for example in examples]
                )
            else:
                position_ids = [torch.arange(len(ids)) for ids in input_ids]
        if "labels" in examples[0]:
            labels = [torch.tensor(example["labels"]) for example in examples]
        else:
            labels = [torch.tensor(example["input_ids"]) for example in examples]
        attention_mask = [torch.ones_like(ids) for ids in input_ids]
        if self.completion_only_loss and "completion_mask" in examples[0]:
            completion_mask = [torch.tensor(example["completion_mask"]) for example in examples]
        if "assistant_masks" in examples[0]:
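The comments in the hunk above refer to cumulative sequence lengths (`cu_seq_lens`) being derived from an all-ones attention mask instead of from the restarting position IDs. The sketch below only illustrates the underlying idea that packed-sequence boundaries are recoverable from restarting position IDs; it is not the transformers/FlashAttention implementation.

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    # A new packed sequence starts wherever the position counter resets to 0.
    starts = torch.nonzero(position_ids == 0, as_tuple=True)[0]
    total = torch.tensor([position_ids.numel()])
    return torch.cat([starts, total]).to(torch.int32)

# [1, 2, 3] + [4, 5] + [6, 7, 8] packed into one row:
print(cu_seqlens_from_position_ids(torch.tensor([0, 1, 2, 0, 1, 0, 1, 2])))
# tensor([0, 3, 5, 8], dtype=torch.int32)
```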
@@ -215,9 +215,8 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
        output = {}
        if self.padding_free:
            input_ids = [torch.cat(input_ids, dim=0)]
            if self.return_position_ids:
                position_ids = [torch.cat(position_ids, dim=0)]
            labels = [torch.cat(labels, dim=0)]
            position_ids = [torch.cat(position_ids, dim=0)]
            if self.completion_only_loss and "completion_mask" in examples[0]:
                completion_mask = [torch.cat(completion_mask, dim=0)]
            if "assistant_masks" in examples[0]:
@@ -230,18 +229,18 @@ class DataCollatorForLanguageModeling(DataCollatorMixin):
            padding_side="right",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        if not has_packed_position_ids:
            attention_mask = [torch.ones_like(input_ids) for input_ids in input_ids]
            output["attention_mask"] = pad(
                attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
            )
        if self.return_position_ids:
            output["position_ids"] = pad(
                position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
            )
        output["labels"] = pad(
            labels, padding_value=-100, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
        )
        if self.padding_free:
            output["position_ids"] = pad(
                position_ids, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
            )
            output["labels"][output["position_ids"] == 0] = -100
        else:
            output["attention_mask"] = pad(
                attention_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
            )
        if self.completion_only_loss and "completion_mask" in examples[0]:
            completion_mask = pad(
                completion_mask, padding_value=0, padding_side="right", pad_to_multiple_of=self.pad_to_multiple_of
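The `output["labels"][output["position_ids"] == 0] = -100` line above is the core of this fix: in a flattened batch, the first token of each packed sequence has no valid previous token within its own sequence, so its label is masked. A small standalone illustration using the values from the padding-free test earlier in this diff:

```python
import torch

position_ids = torch.tensor([[0, 1, 2, 0, 1]])  # two packed sequences: [1, 2, 3] and [4, 5]
labels = torch.tensor([[1, 2, 3, 4, 5]])

# Mask the first token of every packed sequence so the loss never crosses a boundary.
labels[position_ids == 0] = -100
print(labels)  # tensor([[-100, 2, 3, -100, 5]])
```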
@@ -718,11 +717,7 @@ class SFTTrainer(BaseTrainer):
        # BFD packing requires padding-free mode; otherwise, the collator outputs padded attention masks, causing
        # FlashAttention to ignore position_ids and recompute them incorrectly from the padded attention mask.
        self.padding_free = args.padding_free or (args.packing and args.packing_strategy == "bfd")
        use_flash_attention = model.config._attn_implementation in [
            "flash_attention_2",
            "flash_attention_3",
            "kernels-community/vllm-flash-attn3",
        ]
        use_flash_attention = model.config._attn_implementation in FLASH_ATTENTION_VARIANTS
        if self.padding_free:
            if data_collator is not None:
                raise ValueError("Passing a custom data collator is not supported when using padding-free.")
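The two flags derived in the hunk above can be summarized with the sketch below. It mirrors the lines shown in the diff but is only an illustration; `attn_implementation` is passed in explicitly here to keep the example self-contained, and the variant set is copied from the `FLASH_ATTENTION_VARIANTS` definition earlier in this diff.

```python
def resolve_flags(padding_free: bool, packing: bool, packing_strategy: str, attn_implementation: str):
    flash_attention_variants = {
        "flash_attention_2",
        "flash_attention_3",
        "kernels-community/flash-attn",
        "kernels-community/vllm-flash-attn3",
        "kernels-community/flash-attn3",
    }
    # BFD packing implies padding-free collation (see the comment in the hunk above).
    use_padding_free = padding_free or (packing and packing_strategy == "bfd")
    use_flash_attention = attn_implementation in flash_attention_variants
    return use_padding_free, use_flash_attention

print(resolve_flags(False, True, "bfd", "flash_attention_2"))  # (True, True)
print(resolve_flags(False, True, "bfd", "sdpa"))               # (True, False)
```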
@@ -733,13 +728,15 @@ class SFTTrainer(BaseTrainer):
            )
            if not use_flash_attention:
                logger.warning(
                    "Padding-free training is enabled, but the attention implementation is not set to "
                    "'flash_attention_2'. Padding-free training flattens batches into a single sequence, and "
                    "'flash_attention_2' is the only known attention mechanism that reliably supports this. Using "
                    "other implementations may lead to unexpected behavior. To ensure compatibility, set "
                    "`attn_implementation='flash_attention_2'` in the model configuration, or verify that your "
                    "attention mechanism can handle flattened sequences."
                    "Padding-free training is enabled, but the attention implementation is not set to a supported "
                    "flash attention variant. Padding-free training flattens batches into a single sequence, and only "
                    "the following implementations are known to reliably support this: "
                    f"{', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. Using other implementations may lead to "
                    "unexpected behavior. To ensure compatibility, set `attn_implementation` in the model "
                    "configuration to one of these supported options or verify that your attention mechanism can "
                    "handle flattened sequences."
                )

            if args.per_device_train_batch_size == 1 and not args.packing:
                logger.warning(
                    "You are using a per_device_train_batch_size of 1 with padding-free training. Using a batch size "
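For reference, a supported implementation can be requested when the model is loaded; the snippet below uses the standard transformers `attn_implementation` argument with a placeholder checkpoint name.

```python
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",  # placeholder checkpoint, any causal LM works
    attn_implementation="flash_attention_2",  # or another variant from FLASH_ATTENTION_VARIANTS
)
```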
@@ -777,8 +774,6 @@ class SFTTrainer(BaseTrainer):
                pad_token_id=pad_token_id,
                completion_only_loss=self.completion_only_loss,
                padding_free=self.padding_free,
                # Using position_ids without flash_attn hurts the training
                return_position_ids=use_flash_attention,
                pad_to_multiple_of=args.pad_to_multiple_of,
            )
        elif data_collator is None and self._is_vision_dataset:
@@ -792,12 +787,12 @@ class SFTTrainer(BaseTrainer):

        if args.packing and args.packing_strategy == "bfd" and not use_flash_attention:
            logger.warning(
                "You are using packing, but the attention implementation is not set to 'flash_attention_2' or "
                "'kernels-community/vllm-flash-attn3'. Packing flattens batches into a single sequence, and Flash "
                "Attention is the only known attention mechanisms that reliably support this. Using other "
                "implementations may lead to cross-contamination between batches. To avoid this, either disable "
                "packing by setting `packing=False`, or set `attn_implementation='flash_attention_2'` or "
                "`attn_implementation='kernels-community/vllm-flash-attn3'` in the model configuration."
                "You are using packing, but the attention implementation is not set to a supported flash attention "
                "variant. Packing gathers multiple samples into a single sequence, and only the following "
                f"implementations are known to reliably support this: {', '.join(sorted(FLASH_ATTENTION_VARIANTS))}. "
                "Using other implementations may lead to cross-contamination between samples. To avoid this, either "
                "disable packing by setting `packing=False`, or set `attn_implementation` in the model configuration "
                "to one of these supported options."
            )
        if args.assistant_only_loss and not is_conversational(dataset_sample):
            raise ValueError(