[BC Breaking] Remove flex + njt code paths (#161734)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161734
Approved by: https://github.com/jbschlosser

Committed by: PyTorch MergeBot
Parent: dac6a4bf6c
Commit: d08cabe314
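For context, this is the NJT (jagged nested tensor) + FlexAttention usage pattern that the PR removes. The sketch below is adapted from the `create_nested_block_mask` docstring example deleted further down; the `make_njt` helper, shapes, and CUDA device are illustrative assumptions, and the snippet only runs on PyTorch builds that predate this commit.

```python
import torch
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention


# Hypothetical helper: build a jagged NJT of shape (B, n_heads, seq_len*, head_dim),
# where seq_len* varies per batch element.
def make_njt(seq_lens, n_heads=2, head_dim=16, device="cuda"):
    return torch.nested.nested_tensor(
        [torch.randn(s, n_heads, head_dim, device=device) for s in seq_lens],
        layout=torch.jagged,
    ).transpose(1, 2)


query, key, value = make_njt([5, 3]), make_njt([5, 3]), make_njt([5, 3])


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Removed by this PR: create_nested_block_mask and the NJT dispatch path for flex_attention.
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
output = flex_attention(query, key, value, block_mask=block_mask)
```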
@@ -30,9 +30,6 @@
.. autofunction:: create_mask
```
```{eval-rst}
.. autofunction:: create_nested_block_mask
```
```{eval-rst}
.. autofunction:: and_masks
```
```{eval-rst}
@@ -26,7 +26,6 @@ from torch.nested._internal.nested_tensor import (
    NestedTensor,
    ViewNestedFromBuffer,
)
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    SM70OrLater,
@@ -36,7 +35,6 @@ from torch.testing._internal.common_cuda import (
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    flex_attention_supported_platform,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
@@ -60,7 +58,6 @@ from torch.testing._internal.common_utils import (
    parametrize,
    run_tests,
    serialTest,
    skipIfRocm,
    skipIfSlowGradcheckEnv,
    skipIfTorchDynamo,
    subtest,
@@ -7285,124 +7282,6 @@ torch.cuda.synchronize()

        return query, key, value

    @unittest.skip(
        "Temporarily skip - nested tensor backward pass broken after return-max-scores commit"
    )
    @onlyCUDA
    @flex_attention_supported_platform
    @dtypes(torch.float32)
    # non-contiguous with holes not supported yet
    @decorateIf(unittest.skip, lambda params: params["noncontig_with_holes"])
    @parametrize("noncontig_with_holes", [False, True])
    @parametrize("cross_attention", [False, True])
    @skipIfRocm
    def test_flex_attention(self, device, dtype, noncontig_with_holes, cross_attention):
        query, key, value = self._rand_qkv(
            device, dtype, noncontig_with_holes, q_and_kv_match=(not cross_attention)
        )

        # Run FlexAttention with a causal mask
        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        if cross_attention:
            block_mask = create_nested_block_mask(
                causal_mask, 1, 1, query, key, _compile=True
            )
        else:
            block_mask = create_nested_block_mask(
                causal_mask, 1, 1, query, _compile=True
            )

        out_flex = flex_attention(query, key, value, block_mask=block_mask)
        grad_out = torch.randn_like(out_flex)
        grads_flex = torch.autograd.grad(
            out_flex, inputs=(query, key, value), grad_outputs=(grad_out,)
        )
        flex_outs = [out_flex, *grads_flex]

        # Run FlexAttention with a score_mod that represents causal attention
        def causal_score_mod(score, b, h, q_idx, kv_idx):
            return torch.where(q_idx >= kv_idx, score, float("-inf"))

        out_flex2 = flex_attention(query, key, value, score_mod=causal_score_mod)
        grads_flex2 = torch.autograd.grad(
            out_flex2, inputs=(query, key, value), grad_outputs=(grad_out,)
        )
        flex_outs2 = [out_flex2, *grads_flex2]

        # Run causal SDPA for comparison
        out_sdpa = F.scaled_dot_product_attention(query, key, value, is_causal=True)
        grads_sdpa = torch.autograd.grad(
            out_sdpa, inputs=(query, key, value), grad_outputs=(grad_out,)
        )
        sdpa_outs = [out_sdpa, *grads_sdpa]

        # Compare flex vs. SDPA output and grads
        for flex, flex2, sdpa in zip(flex_outs, flex_outs2, sdpa_outs):
            self.assertTrue(flex.is_nested and flex2.is_nested and sdpa.is_nested)
            self.assertEqual(flex, sdpa, atol=1e-2, rtol=1e-2)
            self.assertEqual(flex2, sdpa, atol=1e-2, rtol=1e-2)

    @onlyCUDA
    @flex_attention_supported_platform
    @dtypes(torch.float32)
    def test_flex_attention_converts_stacked_seq_indices(self, device, dtype):
        # This test verifies that a score_mod function written to operate within
        # NJT sequence index space, such as a lookup table, works correctly. This
        # validates that FlexAttention properly converts indices within the
        # "stacked sequence" space used for NJT -> sequence-relative indices.
        query, key, value = self._rand_qkv(device, dtype)

        # Test with score_mod
        score_mod_table = torch.randn(query._max_seqlen, device=device, dtype=dtype)

        def my_score_mod(score, b, h, q_idx, kv_idx):
            return score_mod_table[q_idx]

        flex_attention(query, key, value, score_mod=my_score_mod)

        # Test with batch-specific score_mod
        batch_size = query.size(0)
        batch_table = torch.randn(batch_size, device=device, dtype=dtype)
        # Keep score the same for batch index == 0
        batch_table[0].zero_()

        def batch_specific_score_mod(score, b, h, q_idx, kv_idx):
            return score + batch_table[b]

        def identity_score_mod(score, b, h, q_idx, kv_idx):
            return score

        output = flex_attention(query, key, value, score_mod=batch_specific_score_mod)
        output_identity = flex_attention(
            query, key, value, score_mod=identity_score_mod
        )

        # Guard against a bug where the batch index passed to score_mod is always b == 0.
        # Output would be equivalent to applying an identity score_mod.
        # See https://github.com/pytorch/pytorch/issues/143788
        self.assertFalse(torch.allclose(output._values, output_identity._values))

        # Test with mask_mod
        mask_mod_table = score_mod_table > 0.0

        def my_mask_mod(b, h, q_idx, kv_idx):
            return mask_mod_table[q_idx]

        def my_mask_mod2(b, h, q_idx, kv_idx):
            return mask_mod_table[q_idx] & (b == 0)

        block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, _compile=True)
        output = flex_attention(query, key, value, block_mask=block_mask)

        block_mask2 = create_nested_block_mask(my_mask_mod2, 1, 1, query, _compile=True)
        output2 = flex_attention(query, key, value, block_mask=block_mask2)

        # Guard against a bug where the batch index passed to mask_mod is always b == 0.
        # See https://github.com/pytorch/pytorch/issues/143788
        self.assertFalse(torch.allclose(output._values, output2._values))

    @dtypes(torch.float32)
    def test_apply_(self, device, dtype):
        nt = random_nt_from_dims(
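The deleted test above exercised an equivalence that still holds for dense inputs: FlexAttention with a causal `block_mask`, FlexAttention with a causal `score_mod`, and SDPA with `is_causal=True` should agree. A minimal dense-tensor sketch of that check (shapes, device, and tolerances are illustrative):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

B, H, S, D = 2, 4, 128, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def causal_score_mod(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))


# B=None / H=None broadcast the mask over batch and heads.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")

out_mask = flex_attention(q, k, v, block_mask=block_mask)
out_score = flex_attention(q, k, v, score_mod=causal_score_mod)
out_sdpa = F.scaled_dot_product_attention(q, k, v, is_causal=True)

torch.testing.assert_close(out_mask, out_sdpa, atol=1e-2, rtol=1e-2)
torch.testing.assert_close(out_score, out_sdpa, atol=1e-2, rtol=1e-2)
```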
@@ -505,13 +505,6 @@ def flex_attention_fake_impl(
    ):
        return NotImplemented

    # TODO: Figure out a better way to handle this for NJT than using sum()
    if query.is_nested:
        out = torch.empty_like(query, memory_format=torch.contiguous_format)
        logsumexp = query.sum(dim=-1)
        max_scores = query.max(dim=-1)[0]
        return out, logsumexp, max_scores

    v_head_dim = value.size(-1)
    batch_size, num_heads, seq_len_q, _q_head_dim = query.shape
    logsumexp = query.new_empty(batch_size, num_heads, seq_len_q, dtype=torch.float32)
@@ -2665,144 +2665,6 @@ def matmul_backward_default(func, *args, **kwargs):
    return (grad_self, grad_other)


from torch._higher_order_ops.flex_attention import (
    flex_attention as flex_attention_hop,
    flex_attention_backward as flex_attention_backward_hop,
)
from torch.fx.graph_module import GraphModule


@flex_attention_hop.py_impl(NestedTensor)  # type: ignore[misc]
def flex_njt(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    assert query.dim() == 4 and key.dim() == 4 and value.dim() == 4

    # TODO: Support this if needed; determine if NJT buffers need be unwrapped as dense.
    if any(
        isinstance(buf, torch.Tensor) and buf.is_nested
        for buf in score_mod_other_buffers + mask_mod_other_buffers
    ):
        raise RuntimeError(
            "flex_attention(): Nested tensor score_mod / mask_mod buffers are not "
            "currently supported. Please file an issue if this is important to you."
        )

    # Always set them since 0 sized elements are not handled gracefully
    kernel_options = {**kernel_options, "OUTPUT_MAX": True, "OUTPUT_LOGSUMEXP": True}

    # need to pass dense tensor of shape (B, n_heads, sum(seq_len), D)
    output = flex_attention_hop(
        query.values().unsqueeze(0),
        key.values().unsqueeze(0),
        value.values().unsqueeze(0),
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        kernel_options=kernel_options,
        score_mod_other_buffers=score_mod_other_buffers,
        mask_mod_other_buffers=mask_mod_other_buffers,
    )

    # wrap outputs as NJT
    output_njt = torch.nested.nested_tensor_from_jagged(
        output[0].transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    logsumexp_njt = torch.nested.nested_tensor_from_jagged(
        output[1].transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    max_scores_njt = torch.nested.nested_tensor_from_jagged(
        output[2].transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    return (output_njt, logsumexp_njt, max_scores_njt)


@flex_attention_backward_hop.py_impl(NestedTensor)  # type: ignore[misc]
def flex_njt_backward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
    output = flex_attention_backward_hop(
        query.values().unsqueeze(0),
        key.values().unsqueeze(0),
        value.values().unsqueeze(0),
        out=out.values().unsqueeze(0),
        logsumexp=logsumexp.values().unsqueeze(0),
        grad_out=grad_out.values().unsqueeze(0),
        grad_logsumexp=grad_logsumexp.values().unsqueeze(0),
        fw_graph=fw_graph,
        joint_graph=joint_graph,
        block_mask=block_mask,
        scale=scale,
        kernel_options=kernel_options,
        score_mod_other_buffers=score_mod_other_buffers,
        mask_mod_other_buffers=mask_mod_other_buffers,
    )

    # wrap grads as NJTs
    dense_q_grad, dense_k_grad, dense_v_grad, score_mod_other_buffer_grads = output
    njt_q_grad = torch.nested.nested_tensor_from_jagged(
        dense_q_grad.transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)
    njt_k_grad = torch.nested.nested_tensor_from_jagged(
        dense_k_grad.transpose(1, 2).squeeze(0),
        key._offsets,  # type: ignore[attr-defined]
        key._lengths,  # type: ignore[attr-defined]
        min_seqlen=key._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=key._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)
    njt_v_grad = torch.nested.nested_tensor_from_jagged(
        dense_v_grad.transpose(1, 2).squeeze(0),
        value._offsets,  # type: ignore[attr-defined]
        value._lengths,  # type: ignore[attr-defined]
        min_seqlen=value._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=value._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    return (njt_q_grad, njt_k_grad, njt_v_grad, score_mod_other_buffer_grads)


# Make the dummy available on the C++ side.
@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
def _nested_get_jagged_dummy(func, *args, **kwargs):
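The removed `flex_njt` / `flex_njt_backward` implementations above route NJT inputs through the dense kernel by flattening the jagged dimension into a single "stacked sequence" batch of size 1 and re-wrapping the result. A simplified standalone sketch of that unwrap / re-wrap round-trip (it elides the head-dimension transposes and min/max seqlen bookkeeping in the real code):

```python
import torch

# Toy jagged NJT: three sequences of lengths 2, 3, and 1, each with 2 heads and head_dim 8.
values = torch.randn(6, 2, 8)                 # packed values of shape (sum(S), n_heads, head_dim)
offsets = torch.tensor([0, 2, 5, 6])
nt = torch.nested.nested_tensor_from_jagged(values, offsets)  # shape (3, j1, 2, 8)

# Unwrap: view the NJT as a dense "stacked sequence" batch of one.
dense = nt.values().unsqueeze(0)              # (1, sum(S), n_heads, head_dim)

# ... a dense kernel would run on `dense` here ...

# Re-wrap: rebuild an NJT with the same ragged structure from the dense result.
rewrapped = torch.nested.nested_tensor_from_jagged(dense.squeeze(0), nt.offsets())
assert rewrapped.size(0) == nt.size(0)
```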
@@ -74,7 +74,6 @@ __all__ = [
    "FlexKernelOptions",
    "create_block_mask",
    "create_mask",
    "create_nested_block_mask",
    "or_masks",
    "and_masks",
    "noop_mask",
@@ -1111,179 +1110,6 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
    )


def _nested_mod_func_adapter(
    orig_mod_func: Union[_score_mod_signature, _mask_mod_signature],
    q_nt: torch.Tensor,
    kv_nt: torch.Tensor,
    is_score_mod: bool,
) -> Union[_score_mod_signature, _mask_mod_signature]:
    r"""Adapter to convert a score_mod / mask_mod to be NJT-compatible. The given mod func
    should be written as if operating over a single sequence at a time. This adapter will
    handle conversion from indices operating over a "stacked sequence" of length ``sum(S)``
    for sequence length ``S`` in the NJT to "sequence relative" indices in range ``[0, S)``.

    Args:
        orig_mod_func (Callable): Function to modify attention scores. It takes four or five
            arguments, depending on whether a mask_mod or score_mod func is passed.
        q_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
            structure for query.
        kv_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
            structure for key / value.
        is_score_mod (bool): Indicates whether the mod function is a score_mod.

    Returns:
        nt_score_mod: An NJT-compatible version of orig_score_mod
    """

    # Used to convert indices within the "stacked" sequence (range [0, sum(*)))
    # to "sequence local" indices (range [0, S) for each S).
    def _build_seq_idx(offsets, total_length):
        range_tensor = torch.arange(
            total_length, device=offsets.device, dtype=torch.int32
        )

        # Use searchsorted to find the index for each position
        # NB: This assumes offsets[0] to offsets[-1] spans the packed dim of values.
        # If we ever loosen this restriction, this logic will need to be updated.
        seq_idx = torch.searchsorted(offsets, range_tensor, right=True) - 1
        return seq_idx

    q_offsets = q_nt._offsets  # type: ignore[attr-defined]
    kv_offsets = kv_nt._offsets  # type: ignore[attr-defined]
    q_seq_idx = _build_seq_idx(q_offsets, q_nt._values.shape[q_nt._ragged_idx - 1])  # type: ignore[attr-defined]
    if q_nt is kv_nt:
        kv_seq_idx = q_seq_idx
    else:
        # cross attention case
        kv_seq_idx = _build_seq_idx(
            kv_offsets,
            kv_nt._values.shape[kv_nt._ragged_idx - 1],  # type: ignore[attr-defined]
        )

    # Converts q_idx / kv_idx from [0, total_length) -> [0, S), where S refers
    # to the sequence length for each sequence in the NJT, for use in given
    # score_mod. This allows the user to write a score_mod as if it were
    # operating on a single sequence and the "stacked sequence" is split
    # automatically into individual sequences for them.
    if is_score_mod:

        def nt_score_mod(score, b, h, q_idx, kv_idx):
            b_nested = q_seq_idx[q_idx]
            q_nested = q_idx - q_offsets[q_seq_idx[q_idx]]
            kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]]
            is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx]
            return torch.where(
                is_same_sequence,
                orig_mod_func(score, b_nested, h, q_nested, kv_nested),  # type: ignore[call-arg]
                # don't allow inter-sequence attention
                float("-inf"),
            )

        return nt_score_mod
    else:

        def nt_mask_mod(b, h, q_idx, kv_idx):
            b_nested = q_seq_idx[q_idx]
            q_nested = q_idx - q_offsets[q_seq_idx[q_idx]]
            kv_nested = kv_idx - kv_offsets[kv_seq_idx[kv_idx]]
            # don't allow inter-sequence attention
            is_same_sequence = q_seq_idx[q_idx] == kv_seq_idx[kv_idx]
            return orig_mod_func(b_nested, h, q_nested, kv_nested) & is_same_sequence  # type: ignore[call-arg]

        return nt_mask_mod


def create_nested_block_mask(
    mask_mod: _mask_mod_signature,
    B: Optional[int],
    H: Optional[int],
    q_nt: torch.Tensor,
    kv_nt: Optional[torch.Tensor] = None,
    BLOCK_SIZE: Union[int, tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
    _compile=False,
) -> BlockMask:
    r"""This function creates a nested tensor compatible block mask tuple from a mask_mod
    function. The returned BlockMask will be on the device specified by the input nested tensor.

    Args:
        mask_mod (Callable): mask_mod function. This is a callable that defines the
            masking pattern for the attention mechanism. It takes four arguments:
            b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index).
            It should return a boolean tensor indicating which attention connections are allowed
            (True) or masked out (False).
        B (int): Batch size.
        H (int): Number of query heads.
        q_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
            structure for query. The block mask will be constructed to operate on a "stacked
            sequence" of length ``sum(S)`` for sequence length ``S`` from the NJT.
        kv_nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
            structure for key / value, allowing for cross attention. The block mask will be
            constructed to operate on a "stacked sequence" of length ``sum(S)`` for sequence
            length ``S`` from the NJT. If this is None, ``q_nt`` is used to define the structure
            for key / value as well. Default: None
        BLOCK_SIZE (int or tuple[int, int]): Block size for the block mask. If a single int is
            provided it is used for both query and key/value.

    Returns:
        BlockMask: A BlockMask object that contains the block mask information.

    Example Usage:
        .. code-block:: python

            # shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
            query = torch.nested.nested_tensor(..., layout=torch.jagged)
            key = torch.nested.nested_tensor(..., layout=torch.jagged)
            value = torch.nested.nested_tensor(..., layout=torch.jagged)


            def causal_mask(b, h, q_idx, kv_idx):
                return q_idx >= kv_idx


            block_mask = create_nested_block_mask(
                causal_mask, 1, 1, query, _compile=True
            )
            output = flex_attention(query, key, value, block_mask=block_mask)

        .. code-block:: python

            # shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
            query = torch.nested.nested_tensor(..., layout=torch.jagged)
            key = torch.nested.nested_tensor(..., layout=torch.jagged)
            value = torch.nested.nested_tensor(..., layout=torch.jagged)


            def causal_mask(b, h, q_idx, kv_idx):
                return q_idx >= kv_idx


            # cross attention case: pass both query and key/value NJTs
            block_mask = create_nested_block_mask(
                causal_mask, 1, 1, query, key, _compile=True
            )
            output = flex_attention(query, key, value, block_mask=block_mask)
    """
    # use same structure for kv as for q by default
    if kv_nt is None:
        kv_nt = q_nt
    if q_nt.device != kv_nt.device:
        raise ValueError(
            "create_nested_block_mask(): Expected q_nt and kv_nt to be on the same device"
        )
    return create_block_mask(
        _nested_mod_func_adapter(mask_mod, q_nt, kv_nt, is_score_mod=False),  # type: ignore[arg-type]
        B,
        H,
        q_nt._values.shape[q_nt._ragged_idx - 1],  # type: ignore[attr-defined]
        kv_nt._values.shape[kv_nt._ragged_idx - 1],  # type: ignore[attr-defined]
        device=q_nt.device,  # type: ignore[arg-type]
        # compile is important so we don't materialize a mask_tensor of
        # shape (1, 1, total_seqlen, total_seqlen)
        BLOCK_SIZE=BLOCK_SIZE,
        _compile=_compile,
    )


def _apply_kernel_options(
    query: Tensor,
    key: Tensor,
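The removed `_nested_mod_func_adapter` above converts indices in the NJT's "stacked sequence" (range `[0, sum(S))`) into per-sequence batch and position indices so that user mod functions can be written against a single sequence. A standalone sketch of that conversion, using the same `searchsorted` trick as `_build_seq_idx` (the offsets here are illustrative):

```python
import torch

# Offsets for three sequences of lengths 2, 3, and 1 packed into one stacked dimension of length 6.
offsets = torch.tensor([0, 2, 5, 6])
total_length = int(offsets[-1])

# Map every stacked position to the sequence it belongs to.
range_tensor = torch.arange(total_length)
seq_idx = torch.searchsorted(offsets, range_tensor, right=True) - 1
# seq_idx == tensor([0, 0, 1, 1, 1, 2])

# Convert one stacked index into (sequence index, sequence-relative index),
# mirroring b_nested / q_nested in the removed nt_score_mod / nt_mask_mod.
q_idx = torch.tensor(4)
b_nested = seq_idx[q_idx]             # sequence 1
q_nested = q_idx - offsets[b_nested]  # position 2 within that sequence
```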
@@ -1359,25 +1185,6 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor):
        )


def _validate_nestedness(query: Tensor, key: Tensor, value: Tensor):
    # Currently, inputs can only be all nested or no nested.
    if query.is_nested != key.is_nested or key.is_nested != value.is_nested:
        raise ValueError(
            "FlexAttention does not support mixed nested tensor / non-nested tensor inputs. "
            "Please file an issue requesting this if it is important to you."
        )

    if (
        (query.is_nested and query._lengths is not None)  # type: ignore[attr-defined]
        or (key.is_nested and key._lengths is not None)  # type: ignore[attr-defined]
        or (value.is_nested and value._lengths is not None)  # type: ignore[attr-defined]
    ):
        raise ValueError(
            "FlexAttention does not support nested tensors that are non-contiguous with holes. "
            "Please file an issue requesting this if it is important to you."
        )


def _enforce_mem_layouts(
    query: Tensor, key: Tensor, value: Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -1517,7 +1324,6 @@
    _validate_sdpa_input(query, key, value)
    _validate_embed_dim(query, key, value)
    _validate_device(query, key, value)
    _validate_nestedness(query, key, value)
    query, key, value = _enforce_mem_layouts(query, key, value)
    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
        raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
@@ -1552,14 +1358,6 @@

    if score_mod is None:
        score_mod = _identity
    elif query.is_nested:
        # use same NJT if the ragged structures for sequence lengths match between q and kv
        kv = (
            query
            if query.size(query._ragged_idx) == key.size(query._ragged_idx)  # type: ignore[attr-defined]
            else key
        )
        score_mod = _nested_mod_func_adapter(score_mod, query, kv, is_score_mod=True)  # type: ignore[assignment]

    if block_mask is None:
        block_mask = _create_empty_block_mask(query, key)
@@ -1570,12 +1368,6 @@
    ):
        # This corresponds to the case where we essentially have a "no-op" block mask.
        pass
    elif query.is_nested:
        if block_mask.shape[-2] != query._values.size(query._ragged_idx - 1):  # type: ignore[attr-defined]
            raise RuntimeError(
                f"block_mask of shape {block_mask.shape} is not compatible with nested tensor input "
                f"with total sequence length of {query._values.size(query._ragged_idx - 1)}"  # type: ignore[attr-defined]
            )
    else:
        block_mask_q_len = block_mask.shape[-2]
        block_mask_kv_len = block_mask.shape[-1]