[flex_attention] replace sliced BlockMask noop with helpful error (#164702)

Fixes part of #163314

After slicing BlockMask with `[]`, mask_mod was silently replaced with noop_mask. This caused silent incorrect results when users applied transformations to `sliced_mask.mask_mod`.

Replace noop with `_sliced_mask_mod_error` that raises RuntimeError with guidance to use `base_mask.mask_mod` instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164702
Approved by: https://github.com/drisspg, https://github.com/BoyuanFeng
This commit is contained in:
Amin Sedaghat
2025-10-20 03:46:16 +00:00
committed by PyTorch MergeBot
parent 602ace5eb4
commit 767199fd9b
2 changed files with 55 additions and 1 deletions

View File

@ -4995,6 +4995,28 @@ class TestBlockMask(InductorTestCase):
block_mask.full_kv_indices[:, :, q_index, :],
)
@supported_platform
def test_sliced_blockmask_mask_mod_error(self, device):
    """Passing a sliced BlockMask to compiled flex_attention must raise a
    RuntimeError pointing the user at the base mask's mask_mod, rather than
    silently computing with a noop mask."""
    full_mask = create_block_mask(
        # Standard causal mask: each query attends to itself and earlier keys.
        lambda b, h, q_idx, kv_idx: q_idx >= kv_idx,
        B=1,
        H=1,
        Q_LEN=256,
        KV_LEN=256,
        device=device,
    )
    # Index out a single q-block row; mask_mod on the result is the poisoned one.
    partial_mask = full_mask[:, :, 0]

    query = torch.randn(1, 1, 1, 64, device=device)
    key = torch.randn(1, 1, 256, 64, device=device)
    value = torch.randn(1, 1, 256, 64, device=device)

    fa_compiled = torch.compile(flex_attention)
    expected_msg = "Cannot use mask_mod from a sliced BlockMask"
    with self.assertRaisesRegex(RuntimeError, expected_msg):
        fa_compiled(query, key, value, block_mask=partial_mask)
@supported_platform
def test_block_mask_device_change(self, device):
device = torch.device(device)