mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[flex_attention] replace sliced BlockMask noop with helpful error (#164702)
Fixes part of #163314. After slicing a BlockMask with `[]`, mask_mod was silently replaced with noop_mask. This caused silently incorrect results when users applied transformations to `sliced_mask.mask_mod`. This change replaces the noop with `_sliced_mask_mod_error`, which raises a RuntimeError with guidance to use `base_mask.mask_mod` instead. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164702 Approved by: https://github.com/drisspg, https://github.com/BoyuanFeng
This commit is contained in:
committed by
PyTorch MergeBot
parent
602ace5eb4
commit
767199fd9b
@ -4995,6 +4995,28 @@ class TestBlockMask(InductorTestCase):
|
||||
block_mask.full_kv_indices[:, :, q_index, :],
|
||||
)
|
||||
|
||||
@supported_platform
def test_sliced_blockmask_mask_mod_error(self, device):
    """Test that sliced BlockMask raises helpful error when used with flex_attention"""

    def causal_mask(b, h, q_idx, kv_idx):
        # Standard causal pattern: a query may attend to itself and earlier keys.
        return q_idx >= kv_idx

    base_mask = create_block_mask(
        causal_mask, B=1, H=1, Q_LEN=256, KV_LEN=256, device=device
    )
    # Index the third dimension of the block mask; the result is the sliced
    # mask whose use is expected to be rejected below.
    sliced_mask = base_mask[:, :, 0]

    q = torch.randn(1, 1, 1, 64, device=device)
    k = torch.randn(1, 1, 256, 64, device=device)
    v = torch.randn(1, 1, 256, 64, device=device)

    compiled_fa = torch.compile(flex_attention)
    # Passing the sliced mask without reassigning its mask_mod must fail
    # loudly rather than run with a placeholder mask function.
    with self.assertRaisesRegex(
        RuntimeError, "Cannot use mask_mod from a sliced BlockMask"
    ):
        compiled_fa(q, k, v, block_mask=sliced_mask)
|
||||
|
||||
@supported_platform
|
||||
def test_block_mask_device_change(self, device):
|
||||
device = torch.device(device)
|
||||
|
@ -357,6 +357,33 @@ def noop_mask(
|
||||
return batch.new_ones(size=(), dtype=torch.bool, device=batch.device)
|
||||
|
||||
|
||||
def _sliced_mask_mod_error(
|
||||
batch: Tensor,
|
||||
head: Tensor,
|
||||
token_q: Tensor,
|
||||
token_kv: Tensor,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Raises helpful error when using mask_mod from a sliced BlockMask.
|
||||
|
||||
After slicing a BlockMask, the mask_mod is reset and cannot be used directly.
|
||||
Users must reassign mask_mod from the original (unsliced) BlockMask.
|
||||
"""
|
||||
raise RuntimeError(
|
||||
"Cannot use mask_mod from a sliced BlockMask. "
|
||||
"When you slice a BlockMask using [], the mask_mod attribute is reset. "
|
||||
"You must set it from the original BlockMask's mask_mod."
|
||||
"\n\nIncorrect usage:"
|
||||
"\n base_mask = create_block_mask(my_mask_fn, ...)"
|
||||
"\n sliced_mask = base_mask[:, :, block_idx]"
|
||||
"\n sliced_mask.mask_mod = apply_offset(sliced_mask.mask_mod, offset) # WRONG!"
|
||||
"\n\nCorrect usage:"
|
||||
"\n base_mask = create_block_mask(my_mask_fn, ...)"
|
||||
"\n sliced_mask = base_mask[:, :, block_idx]"
|
||||
"\n sliced_mask.mask_mod = apply_offset(base_mask.mask_mod, offset) # Use base_mask!"
|
||||
)
|
||||
|
||||
|
||||
_DEFAULT_SPARSE_BLOCK_SIZE = 128
|
||||
_LARGE_SPARSE_BLOCK_SIZE = 1 << 30
|
||||
|
||||
@ -710,7 +737,7 @@ class BlockMask:
|
||||
new_full_kv_num_blocks,
|
||||
new_full_kv_indices,
|
||||
BLOCK_SIZE=self.BLOCK_SIZE,
|
||||
mask_mod=None,
|
||||
mask_mod=_sliced_mask_mod_error,
|
||||
seq_lengths=self.seq_lengths,
|
||||
compute_q_blocks=self.q_indices is not None,
|
||||
)
|
||||
@ -1414,6 +1441,11 @@ def flex_attention(
|
||||
if block_mask is None:
|
||||
block_mask = _create_empty_block_mask(query, key)
|
||||
|
||||
# If BlockMask was sliced, its mask_mod is intentionally replaced with an error-raising stub.
|
||||
# This guard ensures we surface the intended error message before any shape-based checks.
|
||||
if getattr(block_mask, "mask_mod", None) is _sliced_mask_mod_error:
|
||||
raise RuntimeError("Cannot use mask_mod from a sliced BlockMask")
|
||||
|
||||
if (
|
||||
block_mask.BLOCK_SIZE[0] == _LARGE_SPARSE_BLOCK_SIZE
|
||||
and block_mask.BLOCK_SIZE[1] == _LARGE_SPARSE_BLOCK_SIZE
|
||||
|
Reference in New Issue
Block a user