[flex_attention] replace sliced BlockMask noop with helpful error (#164702)
Fixes part of #163314.

After slicing a BlockMask with `[]`, its `mask_mod` was silently replaced with `noop_mask`. This produced silently incorrect results when users applied transformations to `sliced_mask.mask_mod`. Replace the noop with `_sliced_mask_mod_error`, which raises a RuntimeError with guidance to use `base_mask.mask_mod` instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164702
Approved by: https://github.com/drisspg, https://github.com/BoyuanFeng
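For context, a minimal sketch of the approach described above, assuming the helper uses the standard `mask_mod` signature. The name `_sliced_mask_mod_error` and the error-message prefix come from this commit; the exact signature and message wording here are assumptions for illustration, not the actual PyTorch source:

```python
# Hedged sketch of the replacement mask_mod installed on sliced BlockMasks.
# Signature follows the standard mask_mod(b, h, q_idx, kv_idx) convention.
def _sliced_mask_mod_error(b, h, q_idx, kv_idx):
    # Invoked wherever the sliced mask's mask_mod would be called; instead
    # of silently returning a noop (all-True) mask, it fails loudly.
    raise RuntimeError(
        "Cannot use mask_mod from a sliced BlockMask. "
        "Use base_mask.mask_mod from the original BlockMask instead."
    )
```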
Committed by: PyTorch MergeBot
Parent: 602ace5eb4
Commit: 767199fd9b
@@ -4995,6 +4995,28 @@ class TestBlockMask(InductorTestCase):
             block_mask.full_kv_indices[:, :, q_index, :],
         )
 
+    @supported_platform
+    def test_sliced_blockmask_mask_mod_error(self, device):
+        """Test that sliced BlockMask raises helpful error when used with flex_attention"""
+
+        def causal_mask(b, h, q_idx, kv_idx):
+            return q_idx >= kv_idx
+
+        base_mask = create_block_mask(
+            causal_mask, B=1, H=1, Q_LEN=256, KV_LEN=256, device=device
+        )
+        sliced_mask = base_mask[:, :, 0]
+
+        q = torch.randn(1, 1, 1, 64, device=device)
+        k = torch.randn(1, 1, 256, 64, device=device)
+        v = torch.randn(1, 1, 256, 64, device=device)
+
+        compiled_fa = torch.compile(flex_attention)
+        with self.assertRaisesRegex(
+            RuntimeError, "Cannot use mask_mod from a sliced BlockMask"
+        ):
+            compiled_fa(q, k, v, block_mask=sliced_mask)
+
     @supported_platform
     def test_block_mask_device_change(self, device):
         device = torch.device(device)
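As a usage note, a hedged sketch of the pattern the new error steers users toward, mirroring the mask and shapes from the test above (the `"cuda"` device here is an assumption; any supported device works):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

base_mask = create_block_mask(causal_mask, B=1, H=1, Q_LEN=256, KV_LEN=256, device="cuda")
sliced_mask = base_mask[:, :, 0]

mod = base_mask.mask_mod  # correct: the original mask_mod is intact
# sliced_mask.mask_mod is now the error helper; calling it (directly or
# via flex_attention) raises RuntimeError instead of silently applying a
# noop mask.
```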