[BC Breaking] Remove flex + njt code paths (#161734)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161734
Approved by: https://github.com/jbschlosser
Author: drisspg
Date: 2025-09-15 12:05:41 -07:00
Committed by: PyTorch MergeBot
Parent: dac6a4bf6c
Commit: d08cabe314
5 changed files with 0 additions and 477 deletions

View File

@@ -30,9 +30,6 @@
.. autofunction:: create_mask
```
```{eval-rst}
.. autofunction:: create_nested_block_mask
```
```{eval-rst}
.. autofunction:: and_masks
```
```{eval-rst}

View File

@@ -26,7 +26,6 @@ from torch.nested._internal.nested_tensor import (
    NestedTensor,
    ViewNestedFromBuffer,
)
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    SM70OrLater,
@@ -36,7 +35,6 @@ from torch.testing._internal.common_cuda import (
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    flex_attention_supported_platform,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
@@ -60,7 +58,6 @@ from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    serialTest,
    skipIfRocm,
    skipIfSlowGradcheckEnv,
    skipIfTorchDynamo,
    subtest,
@@ -7285,124 +7282,6 @@ torch.cuda.synchronize()
        return query, key, value

    @unittest.skip(
        "Temporarily skip - nested tensor backward pass broken after return-max-scores commit"
    )
    @onlyCUDA
    @flex_attention_supported_platform
    @dtypes(torch.float32)
    # non-contiguous with holes not supported yet
    @decorateIf(unittest.skip, lambda params: params["noncontig_with_holes"])
    @parametrize("noncontig_with_holes", [False, True])
    @parametrize("cross_attention", [False, True])
    @skipIfRocm
    def test_flex_attention(self, device, dtype, noncontig_with_holes, cross_attention):
        query, key, value = self._rand_qkv(
            device, dtype, noncontig_with_holes, q_and_kv_match=(not cross_attention)
        )

        # Run FlexAttention with a causal mask
        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        if cross_attention:
            block_mask = create_nested_block_mask(
                causal_mask, 1, 1, query, key, _compile=True
            )
        else:
            block_mask = create_nested_block_mask(
                causal_mask, 1, 1, query, _compile=True
            )
        out_flex = flex_attention(query, key, value, block_mask=block_mask)
        grad_out = torch.randn_like(out_flex)
        grads_flex = torch.autograd.grad(
            out_flex, inputs=(query, key, value), grad_outputs=(grad_out,)
        )
        flex_outs = [out_flex, *grads_flex]

        # Run FlexAttention with a score_mod that represents causal attention
        def causal_score_mod(score, b, h, q_idx, kv_idx):
            return torch.where(q_idx >= kv_idx, score, float("-inf"))

        out_flex2 = flex_attention(query, key, value, score_mod=causal_score_mod)
        grads_flex2 = torch.autograd.grad(
            out_flex2, inputs=(query, key, value), grad_outputs=(grad_out,)
        )
        flex_outs2 = [out_flex2, *grads_flex2]

        # Run causal SDPA for comparison
        out_sdpa = F.scaled_dot_product_attention(query, key, value, is_causal=True)
        grads_sdpa = torch.autograd.grad(
            out_sdpa, inputs=(query, key, value), grad_outputs=(grad_out,)
        )
        sdpa_outs = [out_sdpa, *grads_sdpa]

        # Compare flex vs. SDPA output and grads
        for flex, flex2, sdpa in zip(flex_outs, flex_outs2, sdpa_outs):
            self.assertTrue(flex.is_nested and flex2.is_nested and sdpa.is_nested)
            self.assertEqual(flex, sdpa, atol=1e-2, rtol=1e-2)
            self.assertEqual(flex2, sdpa, atol=1e-2, rtol=1e-2)
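
For reference, the equivalence this removed test exercised, a causal mask_mod / score_mod versus is_causal=True SDPA, still holds for dense tensors through the public FlexAttention API. A minimal sketch in eager mode, with hypothetical shapes:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 2, 4, 128, 64  # hypothetical dense shapes (B, H, S, D)
q, k, v = (torch.randn(B, H, S, D) for _ in range(3))


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Broadcast over batch/heads (B=None, H=None); mask built on CPU here.
block_mask = create_block_mask(
    causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cpu"
)
out_flex = flex_attention(q, k, v, block_mask=block_mask)
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)
assert torch.allclose(out_flex, out_sdpa, atol=1e-2, rtol=1e-2)
```
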
    @onlyCUDA
    @flex_attention_supported_platform
    @dtypes(torch.float32)
    def test_flex_attention_converts_stacked_seq_indices(self, device, dtype):
        # This test verifies that a score_mod function written to operate within
        # NJT sequence index space, such as a lookup table, works correctly. This
        # validates that FlexAttention properly converts indices within the
        # "stacked sequence" space used for NJT -> sequence-relative indices.
        query, key, value = self._rand_qkv(device, dtype)

        # Test with score_mod
        score_mod_table = torch.randn(query._max_seqlen, device=device, dtype=dtype)

        def my_score_mod(score, b, h, q_idx, kv_idx):
            return score_mod_table[q_idx]

        flex_attention(query, key, value, score_mod=my_score_mod)

        # Test with batch-specific score_mod
        batch_size = query.size(0)
        batch_table = torch.randn(batch_size, device=device, dtype=dtype)
        # Keep score the same for batch index == 0
        batch_table[0].zero_()

        def batch_specific_score_mod(score, b, h, q_idx, kv_idx):
            return score + batch_table[b]

        def identity_score_mod(score, b, h, q_idx, kv_idx):
            return score

        output = flex_attention(query, key, value, score_mod=batch_specific_score_mod)
        output_identity = flex_attention(
            query, key, value, score_mod=identity_score_mod
        )

        # Guard against a bug where the batch index passed to score_mod is always b == 0.
        # Output would be equivalent to applying an identity score_mod.
        # See https://github.com/pytorch/pytorch/issues/143788
        self.assertFalse(torch.allclose(output._values, output_identity._values))

        # Test with mask_mod
        mask_mod_table = score_mod_table > 0.0

        def my_mask_mod(b, h, q_idx, kv_idx):
            return mask_mod_table[q_idx]

        def my_mask_mod2(b, h, q_idx, kv_idx):
            return mask_mod_table[q_idx] & (b == 0)

        block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, _compile=True)
        output = flex_attention(query, key, value, block_mask=block_mask)

        block_mask2 = create_nested_block_mask(my_mask_mod2, 1, 1, query, _compile=True)
        output2 = flex_attention(query, key, value, block_mask=block_mask2)

        # Guard against a bug where the batch index passed to mask_mod is always b == 0.
        # See https://github.com/pytorch/pytorch/issues/143788
        self.assertFalse(torch.allclose(output._values, output2._values))

    @dtypes(torch.float32)
    def test_apply_(self, device, dtype):
        nt = random_nt_from_dims(

View File

@@ -505,13 +505,6 @@ def flex_attention_fake_impl(
    ):
        return NotImplemented

    # TODO: Figure out a better way to handle this for NJT than using sum()
    if query.is_nested:
        out = torch.empty_like(query, memory_format=torch.contiguous_format)
        logsumexp = query.sum(dim=-1)
        max_scores = query.max(dim=-1)[0]
        return out, logsumexp, max_scores

    v_head_dim = value.size(-1)
    batch_size, num_heads, seq_len_q, _q_head_dim = query.shape
    logsumexp = query.new_empty(batch_size, num_heads, seq_len_q, dtype=torch.float32)
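
The deleted NJT branch above fabricates fake outputs by reusing reductions over the head dim: sum and max over the last dim produce tensors with the ragged (B, H, S*) shape needed for logsumexp and max_scores, which is awkward to build directly for an NJT (hence the TODO). A minimal dense-tensor sketch of the same shape bookkeeping, with hypothetical sizes:

```python
import torch

query = torch.randn(2, 4, 16, 64)  # hypothetical (B, H, S, D)

out = torch.empty_like(query)      # attention output mirrors the query shape
logsumexp = query.sum(dim=-1)      # (B, H, S): head dim reduced away
max_scores = query.max(dim=-1)[0]  # (B, H, S): same shape as logsumexp

assert logsumexp.shape == query.shape[:-1]
assert max_scores.shape == query.shape[:-1]
```
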

View File

@@ -2665,144 +2665,6 @@ def matmul_backward_default(func, *args, **kwargs):
    return (grad_self, grad_other)


from torch._higher_order_ops.flex_attention import (
    flex_attention as flex_attention_hop,
    flex_attention_backward as flex_attention_backward_hop,
)
from torch.fx.graph_module import GraphModule


@flex_attention_hop.py_impl(NestedTensor)  # type: ignore[misc]
def flex_njt(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert query.dim() == 4 and key.dim() == 4 and value.dim() == 4

    # TODO: Support this if needed; determine if NJT buffers need be unwrapped as dense.
    if any(
        isinstance(buf, torch.Tensor) and buf.is_nested
        for buf in score_mod_other_buffers + mask_mod_other_buffers
    ):
        raise RuntimeError(
            "flex_attention(): Nested tensor score_mod / mask_mod buffers are not "
            "currently supported. Please file an issue if this is important to you."
        )

    # Always set them since 0 sized elements are not handled gracefully
    kernel_options = {**kernel_options, "OUTPUT_MAX": True, "OUTPUT_LOGSUMEXP": True}

    # need to pass dense tensor of shape (B, n_heads, sum(seq_len), D)
    output = flex_attention_hop(
        query.values().unsqueeze(0),
        key.values().unsqueeze(0),
        value.values().unsqueeze(0),
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        kernel_options=kernel_options,
        score_mod_other_buffers=score_mod_other_buffers,
        mask_mod_other_buffers=mask_mod_other_buffers,
    )

    # wrap outputs as NJT
    output_njt = torch.nested.nested_tensor_from_jagged(
        output[0].transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    logsumexp_njt = torch.nested.nested_tensor_from_jagged(
        output[1].transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    max_scores_njt = torch.nested.nested_tensor_from_jagged(
        output[2].transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    return (output_njt, logsumexp_njt, max_scores_njt)
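
The wrapper above follows a values/offsets round trip: the NJT's packed values are pushed through the dense kernel, and each result is rewrapped with the query's ragged metadata. A minimal sketch of that round trip using the public NJT API, with hypothetical sizes (the removed code additionally threads min_seqlen / max_seqlen and the private _offsets / _lengths):

```python
import torch

# Two sequences of lengths 3 and 5, packed as a jagged-layout NJT.
nt = torch.nested.nested_tensor(
    [torch.randn(3, 8), torch.randn(5, 8)], layout=torch.jagged
)  # logical shape (B=2, S*, D=8)

values = nt.values()    # packed buffer of shape (sum(S), D) == (8, 8)
offsets = nt.offsets()  # ragged boundaries: tensor([0, 3, 8])

# ... a dense kernel would operate on `values` here ...

# Rewrap the (possibly transformed) buffer with the original ragged structure.
rewrapped = torch.nested.nested_tensor_from_jagged(values, offsets)
assert rewrapped.size(0) == nt.size(0)
```
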
@flex_attention_backward_hop.py_impl(NestedTensor)  # type: ignore[misc]
def flex_njt_backward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
    output = flex_attention_backward_hop(
        query.values().unsqueeze(0),
        key.values().unsqueeze(0),
        value.values().unsqueeze(0),
        out=out.values().unsqueeze(0),
        logsumexp=logsumexp.values().unsqueeze(0),
        grad_out=grad_out.values().unsqueeze(0),
        grad_logsumexp=grad_logsumexp.values().unsqueeze(0),
        fw_graph=fw_graph,
        joint_graph=joint_graph,
        block_mask=block_mask,
        scale=scale,
        kernel_options=kernel_options,
        score_mod_other_buffers=score_mod_other_buffers,
        mask_mod_other_buffers=mask_mod_other_buffers,
    )

    # wrap grads as NJTs
    dense_q_grad, dense_k_grad, dense_v_grad, score_mod_other_buffer_grads = output
    njt_q_grad = torch.nested.nested_tensor_from_jagged(
        dense_q_grad.transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)
    njt_k_grad = torch.nested.nested_tensor_from_jagged(
        dense_k_grad.transpose(1, 2).squeeze(0),
        key._offsets,  # type: ignore[attr-defined]
        key._lengths,  # type: ignore[attr-defined]
        min_seqlen=key._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=key._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)
    njt_v_grad = torch.nested.nested_tensor_from_jagged(
        dense_v_grad.transpose(1, 2).squeeze(0),
        value._offsets,  # type: ignore[attr-defined]
        value._lengths,  # type: ignore[attr-defined]
        min_seqlen=value._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=value._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    return (njt_q_grad, njt_k_grad, njt_v_grad, score_mod_other_buffer_grads)


# Make the dummy available on the C++ side.
@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
def _nested_get_jagged_dummy(func, *args, **kwargs):

View File

@@ -74,7 +74,6 @@ __all__ = [
    "FlexKernelOptions",
    "create_block_mask",
    "create_mask",
    "create_nested_block_mask",
    "or_masks",
    "and_masks",
    "noop_mask",
@@ -1111,179 +1110,6 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
    )


def _nested_mod_func_adapter(
    orig_mod_func: Union[_score_mod_signature, _mask_mod_signature],
    q_nt: torch.Tensor,
    kv_nt: torch.Tensor,
    is_score_mod: bool,
) -> Union[_score_mod_signature, _mask_mod_signature]:
r"""Adapter to convert a score_mod / mask_mod to be NJT-compatible. The given mod func
should be written as if operating over a single sequence at a item. This adapter will
handle conversion from indices operating over a "stacked sequence" of length ``sum(S)``
for sequence length ``S`` in the NJT to "sequence relative" indices in range ``[0, S)``.
Args:
orig_mod_func (Callable): Function to modify attention scores. It takes four or five
arguments, depending on whether a mask_mod or score_mod func is passed.
q_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
structure for query.
kv_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
structure for key / value.
is_score_mod (bool): Indicates whether the mod function is a score_mod.
Returns:
nt_score_mod: An NJT-compatible version of orig_score_mod
"""
    # Used to convert indices within the "stacked" sequence (range [0, sum(*)))
    # to "sequence local" indices (range [0, S) for each S).
    def _build_seq_idx(offsets, total_length):
        range_tensor = torch.arange(
            total_length, device=offsets.device, dtype=torch.int32
        )

        # Use searchsorted to find the index for each position
        # NB: This assumes offsets[0] to offsets[-1] spans the packed dim of values.
        # If we ever loosen this restriction, this logic will need to be updated.
        seq_idx = torch.searchsorted(offsets, range_tensor, right=True) - 1
        return seq_idx

    q_offsets = q_nt._offsets  # type: ignore[attr-defined]
    kv_offsets = kv_nt._offsets  # type: ignore[attr-defined]
    q_seq_idx = _build_seq_idx(q_offsets, q_nt._values.shape[q_nt._ragged_idx - 1])  # type: ignore[attr-defined]
    if q_nt is kv_nt:
        kv_seq_idx = q_seq_idx
    else:
        # cross attention case
        kv_seq_idx = _build_seq_idx(
            kv_offsets,
            kv_nt._values.shape[kv_nt._ragged_idx - 1],  # type: ignore[attr-defined]
        )

    # Converts q_idx / kv_idx from [0, total_length) -> [0, S), where S refers
    # to the sequence length for each sequence in the NJT, for use in given
    # score_mod. This allows the user to write a score_mod as if it were
    # operating on a single sequence and the "stacked sequence" is split
    # automatically into individual sequences for them.
    if is_score_mod:

        def nt_score_mod(score, b, h, q_idx, kv_idx):
            b_nested = q_seq_idx[q_idx]
            q_nested = q_idx - q_offsets[q_seq_idx[q_idx]]
            kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]]
            is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx]
            return torch.where(
                is_same_sequence,
                orig_mod_func(score, b_nested, h, q_nested, kv_nested),  # type: ignore[call-arg]
                # don't allow inter-sequence attention
                float("-inf"),
            )

        return nt_score_mod
    else:

        def nt_mask_mod(b, h, q_idx, kv_idx):
            b_nested = q_seq_idx[q_idx]
            q_nested = q_idx - q_offsets[q_seq_idx[q_idx]]
            kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]]
            # don't allow inter-sequence attention
            is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx]
            return orig_mod_func(b_nested, h, q_nested, kv_nested) & is_same_sequence  # type: ignore[call-arg]

        return nt_mask_mod
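
The adapter's index math can be checked in isolation: searchsorted over the offsets maps each position in the stacked dim to the sequence that owns it, and subtracting that sequence's start offset yields the sequence-relative index. A small standalone sketch with hypothetical offsets:

```python
import torch

# Offsets for two packed sequences of lengths 3 and 5.
offsets = torch.tensor([0, 3, 8])
total_length = int(offsets[-1])

stacked_idx = torch.arange(total_length)
# Which sequence does each stacked position belong to?
seq_idx = torch.searchsorted(offsets, stacked_idx, right=True) - 1
# seq_idx: tensor([0, 0, 0, 1, 1, 1, 1, 1])

# Sequence-relative index in [0, S) for each stacked position.
local_idx = stacked_idx - offsets[seq_idx]
# local_idx: tensor([0, 1, 2, 0, 1, 2, 3, 4])
```
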
def create_nested_block_mask(
    mask_mod: _mask_mod_signature,
    B: Optional[int],
    H: Optional[int],
    q_nt: torch.Tensor,
    kv_nt: Optional[torch.Tensor] = None,
    BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
    _compile=False,
) -> BlockMask:
    r"""This function creates a nested tensor compatible block mask tuple from a mask_mod
    function. The returned BlockMask will be on the device specified by the input nested tensor.

    Args:
        mask_mod (Callable): mask_mod function. This is a callable that defines the
            masking pattern for the attention mechanism. It takes four arguments:
            b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index).
            It should return a boolean tensor indicating which attention connections are allowed
            (True) or masked out (False).
        B (int): Batch size.
        H (int): Number of query heads.
        q_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
            structure for query. The block mask will be constructed to operate on a "stacked
            sequence" of length ``sum(S)`` for sequence length ``S`` from the NJT.
        kv_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
            structure for key / value, allowing for cross attention. The block mask will be
            constructed to operate on a "stacked sequence" of length ``sum(S)`` for sequence
            length ``S`` from the NJT. If this is None, ``q_nt`` is used to define the structure
            for key / value as well. Default: None
        BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is
            provided it is used for both query and key/value.

    Returns:
        BlockMask: A BlockMask object that contains the block mask information.

    Example Usage:
        .. code-block:: python

            # shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
            query = torch.nested.nested_tensor(..., layout=torch.jagged)
            key = torch.nested.nested_tensor(..., layout=torch.jagged)
            value = torch.nested.nested_tensor(..., layout=torch.jagged)

            def causal_mask(b, h, q_idx, kv_idx):
                return q_idx >= kv_idx

            block_mask = create_nested_block_mask(
                causal_mask, 1, 1, query, _compile=True
            )
            output = flex_attention(query, key, value, block_mask=block_mask)

        .. code-block:: python

            # shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
            query = torch.nested.nested_tensor(..., layout=torch.jagged)
            key = torch.nested.nested_tensor(..., layout=torch.jagged)
            value = torch.nested.nested_tensor(..., layout=torch.jagged)

            def causal_mask(b, h, q_idx, kv_idx):
                return q_idx >= kv_idx

            # cross attention case: pass both query and key/value NJTs
            block_mask = create_nested_block_mask(
                causal_mask, 1, 1, query, key, _compile=True
            )
            output = flex_attention(query, key, value, block_mask=block_mask)
    """
    # use same structure for kv as for q by default
    if kv_nt is None:
        kv_nt = q_nt
    if q_nt.device != kv_nt.device:
        raise ValueError(
            "create_nested_block_mask(): Expected q_nt and kv_nt to be on the same device"
        )
    return create_block_mask(
        _nested_mod_func_adapter(mask_mod, q_nt, kv_nt, is_score_mod=False),  # type: ignore[arg-type]
        B,
        H,
        q_nt._values.shape[q_nt._ragged_idx - 1],  # type: ignore[attr-defined]
        kv_nt._values.shape[kv_nt._ragged_idx - 1],  # type: ignore[attr-defined]
        device=q_nt.device,  # type: ignore[arg-type]
        # compile is important so we don't materialize a mask_tensor of
        # shape (1, 1, total_seqlen, total_seqlen)
        BLOCK_SIZE=BLOCK_SIZE,
        _compile=_compile,
    )


def _apply_kernel_options(
    query: Tensor,
    key: Tensor,
@@ -1359,25 +1185,6 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor):
        )


def _validate_nestedness(query: Tensor, key: Tensor, value: Tensor):
    # Currently, inputs can only be all nested or no nested.
    if query.is_nested != key.is_nested or key.is_nested != value.is_nested:
        raise ValueError(
            "FlexAttention does not support mixed nested tensor / non-nested tensor inputs. "
            "Please file an issue requesting this if it is important to you."
        )

    if (
        (query.is_nested and query._lengths is not None)  # type: ignore[attr-defined]
        or (key.is_nested and key._lengths is not None)  # type: ignore[attr-defined]
        or (value.is_nested and value._lengths is not None)  # type: ignore[attr-defined]
    ):
        raise ValueError(
            "FlexAttention does not support nested tensors that are non-contiguous with holes. "
            "Please file an issue requesting this if it is important to you."
        )


def _enforce_mem_layouts(
    query: Tensor, key: Tensor, value: Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -1517,7 +1324,6 @@ def flex_attention(
    _validate_sdpa_input(query, key, value)
    _validate_embed_dim(query, key, value)
    _validate_device(query, key, value)
    _validate_nestedness(query, key, value)
    query, key, value = _enforce_mem_layouts(query, key, value)
    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
        raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
@@ -1552,14 +1358,6 @@ def flex_attention(
    if score_mod is None:
        score_mod = _identity
    elif query.is_nested:
        # use same NJT if the ragged structures for sequence lengths match between q and kv
        kv = (
            query
            if query.size(query._ragged_idx) == key.size(query._ragged_idx)  # type: ignore[attr-defined]
            else key
        )
        score_mod = _nested_mod_func_adapter(score_mod, query, kv, is_score_mod=True)  # type: ignore[assignment]

    if block_mask is None:
        block_mask = _create_empty_block_mask(query, key)
@@ -1570,12 +1368,6 @@
    ):
        # This corresponds to the case where we essentially have a "no-op" block mask.
        pass
    elif query.is_nested:
        if block_mask.shape[-2] != query._values.size(query._ragged_idx - 1):  # type: ignore[attr-defined]
            raise RuntimeError(
                f"block_mask of shape {block_mask.shape} is not compatible with nested tensor input "
                f"with total sequence length of {query._values.size(query._ragged_idx - 1)}"  # type: ignore[attr-defined]
            )
    else:
        block_mask_q_len = block_mask.shape[-2]
        block_mask_kv_len = block_mask.shape[-1]