[flex_attention] replace sliced BlockMask noop with helpful error (#164702)

Fixes part of #163314

After slicing a BlockMask with `[]`, its mask_mod was silently replaced with noop_mask. Any transformation a user then applied to `sliced_mask.mask_mod` therefore operated on the noop and produced incorrect results without warning.

Replace the noop with `_sliced_mask_mod_error`, which raises a RuntimeError guiding users to use `base_mask.mask_mod` instead.
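As a minimal sketch of the affected pattern (assuming the public `create_block_mask` API from `torch.nn.attention.flex_attention`; `my_mask_fn` and `apply_offset` are hypothetical placeholders mirroring the names used in the error message):

```python
from torch.nn.attention.flex_attention import create_block_mask

def my_mask_fn(b, h, q_idx, kv_idx):
    # Hypothetical causal mask_mod.
    return q_idx >= kv_idx

def apply_offset(mask_mod, offset):
    # Hypothetical transformation that shifts the query index before
    # delegating to the wrapped mask_mod.
    def offset_mod(b, h, q_idx, kv_idx):
        return mask_mod(b, h, q_idx + offset, kv_idx)
    return offset_mod

base_mask = create_block_mask(my_mask_fn, B=None, H=None, Q_LEN=1024, KV_LEN=1024, device="cpu")
sliced_mask = base_mask[:, :, :4]  # keep only the first 4 query blocks

# Before this change: sliced_mask.mask_mod was silently noop_mask, so wrapping it
# produced an always-true mask with no warning:
# sliced_mask.mask_mod = apply_offset(sliced_mask.mask_mod, 512)  # wrong
# After this change, that stale mask_mod raises a RuntimeError as soon as it is invoked.

# Correct: derive the new mask_mod from the original, unsliced BlockMask.
sliced_mask.mask_mod = apply_offset(base_mask.mask_mod, 512)
```

The diff below makes the wrong pattern fail loudly at the point where the stale mask_mod would otherwise be used.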

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164702
Approved by: https://github.com/drisspg, https://github.com/BoyuanFeng
Amin Sedaghat
2025-10-20 03:46:16 +00:00
committed by PyTorch MergeBot
parent 602ace5eb4
commit 767199fd9b
2 changed files with 55 additions and 1 deletion


@@ -357,6 +357,33 @@ def noop_mask(
    return batch.new_ones(size=(), dtype=torch.bool, device=batch.device)

def _sliced_mask_mod_error(
    batch: Tensor,
    head: Tensor,
    token_q: Tensor,
    token_kv: Tensor,
) -> Tensor:
    """
    Raises helpful error when using mask_mod from a sliced BlockMask.
    After slicing a BlockMask, the mask_mod is reset and cannot be used directly.
    Users must reassign mask_mod from the original (unsliced) BlockMask.
    """
    raise RuntimeError(
        "Cannot use mask_mod from a sliced BlockMask. "
        "When you slice a BlockMask using [], the mask_mod attribute is reset. "
        "You must set it from the original BlockMask's mask_mod."
        "\n\nIncorrect usage:"
        "\n base_mask = create_block_mask(my_mask_fn, ...)"
        "\n sliced_mask = base_mask[:, :, block_idx]"
        "\n sliced_mask.mask_mod = apply_offset(sliced_mask.mask_mod, offset) # WRONG!"
        "\n\nCorrect usage:"
        "\n base_mask = create_block_mask(my_mask_fn, ...)"
        "\n sliced_mask = base_mask[:, :, block_idx]"
        "\n sliced_mask.mask_mod = apply_offset(base_mask.mask_mod, offset) # Use base_mask!"
    )

_DEFAULT_SPARSE_BLOCK_SIZE = 128
_LARGE_SPARSE_BLOCK_SIZE = 1 << 30
@@ -710,7 +737,7 @@ class BlockMask:
            new_full_kv_num_blocks,
            new_full_kv_indices,
            BLOCK_SIZE=self.BLOCK_SIZE,
-           mask_mod=None,
+           mask_mod=_sliced_mask_mod_error,
            seq_lengths=self.seq_lengths,
            compute_q_blocks=self.q_indices is not None,
@@ -1414,6 +1441,11 @@ def flex_attention(
    if block_mask is None:
        block_mask = _create_empty_block_mask(query, key)

    # If BlockMask was sliced, its mask_mod is intentionally replaced with an error-raising stub.
    # This guard ensures we surface the intended error message before any shape-based checks.
    if getattr(block_mask, "mask_mod", None) is _sliced_mask_mod_error:
        raise RuntimeError("Cannot use mask_mod from a sliced BlockMask")

    if (
        block_mask.BLOCK_SIZE[0] == _LARGE_SPARSE_BLOCK_SIZE
        and block_mask.BLOCK_SIZE[1] == _LARGE_SPARSE_BLOCK_SIZE