FlexAttention support for NJT (#136792)
This PR adds FlexAttention + NJT support. In particular, it:

* Treats the packed sequence dim of input NJTs as one giant "stacked sequence" to handle raggedness. To ensure user `score_mod` / `mask_mod` functions can still be written in the original NJT sequence space, conversions from indices within the giant "stacked sequence" to sequence-relative indices are handled automatically (see the sketch after the commit metadata below).
* Provides `py_impls` for `NestedTensor` to the HOPs for flex attention forward / backward that simply wrap / unwrap NJTs appropriately.
* Adds barebones `new_empty()` support to NJT since FlexAttention utilizes this repeatedly; right now, only `new_empty()` with a shape of `()` is supported.
* Tests that FlexAttention with a causal mask matches causal SDPA.
* Adds a new public API for FlexAttention usage:
  * `create_nested_block_mask(mask_mod, B, H, njt, BLOCK_SIZE, _compile)` - the NJT analogue of `create_block_mask()` that utilizes the `njt`'s ragged structure to create an appropriately-sized block mask (e.g. `(1, 1, total_seqlen, total_seqlen)`). This function handles the index conversion from "stacked sequence" space to relative sequence space.
  * Minor note: as this is a public API, the function is purposefully named with "nested" instead of "njt" to keep the latter as an informal, mostly internal-only term.

Example usage:

```python
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

query = ...  # NJT of shape (B, H, S*, D)
key = ...  # NJT of shape (B, H, S*, D)
value = ...  # NJT of shape (B, H, S*, D)

# create_nested_block_mask() automatically converts indices from "stacked sequence" space -> relative sequence space
block_mask = create_nested_block_mask(causal_mask, 1, 1, query)  # block mask conceptual shape is (B, H, sum(S*), sum(S*))
output = flex_attention(query, key, value, block_mask=block_mask)

def causal_score_mod(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# flex_attention() automatically converts indices from "stacked sequence" space -> relative sequence space for NJT inputs
output2 = flex_attention(query, key, value, score_mod=causal_score_mod)
```

TODO:

* ~~Determine the right level of abstraction for public API helpers + move them alongside other helpers~~ Verify this with others though
* ~~Some cleanup~~
  * ~~`njt_score_mod_adapter`~~
  * ~~Q: should `create_njt_block_mask()` call `njt_mask_mod_adapter()` so we don't need two calls?~~
* Can we avoid materializing the `sum(s)`-length `seq_idx` used for conversion between stacked sequence -> sequence relative indices?
  * Not for now, although future work may deepen the integration between Flex + NJT (possibly requiring custom templates). We should try to cache this though.
* ~~Demonstrate non-causal mask~~
* Support non-contiguous NJTs with holes (**booted to future PR**)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136792
Approved by: https://github.com/drisspg
ghstack dependencies: #138841
Committed by: PyTorch MergeBot
Parent: 4cd985a886
Commit: 8ba9063002
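To make the automatic index conversion mentioned above concrete, here is a minimal standalone sketch (not part of the diff) of how "stacked sequence" indices map back to sequence-relative indices via the NJT offsets, mirroring the `_build_seq_idx` helper added in this PR; the offsets below are made-up example data.

```python
import torch

# Offsets for an NJT whose three sequences have lengths 3, 2, and 4;
# the packed "stacked sequence" therefore has total length 9.
offsets = torch.tensor([0, 3, 5, 9])
total_length = int(offsets[-1])

# For every position in the stacked sequence, find the sequence it belongs to.
stacked_idx = torch.arange(total_length)
seq_idx = torch.searchsorted(offsets, stacked_idx, right=True) - 1

# Subtracting the owning sequence's starting offset yields sequence-relative
# indices, which is the index space user score_mod / mask_mod functions use.
relative_idx = stacked_idx - offsets[seq_idx]

print(seq_idx.tolist())       # [0, 0, 0, 1, 1, 2, 2, 2, 2]
print(relative_idx.tolist())  # [0, 1, 2, 0, 1, 0, 1, 2, 3]
```

Attention between positions with different `seq_idx` values is additionally disallowed by the adapter, so a causal `mask_mod` written for a single sequence behaves correctly over the whole packed batch.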
@@ -14,6 +14,7 @@ BlockMask Utilities

.. autofunction:: create_block_mask
.. autofunction:: create_mask
.. autofunction:: create_nested_block_mask
.. autofunction:: and_masks
.. autofunction:: or_masks
.. autofunction:: noop_mask
@@ -32,18 +32,13 @@ from torch.nn.attention.flex_attention import (
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
    flex_attention_supported_platform as supported_platform,
)
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.utils._triton import has_triton


# Skip tests if Triton is not available
supported_platform = skipUnless(
    torch.cuda.is_available()
    and has_triton()
    and torch.cuda.get_device_capability() >= (8, 0),
    "Requires CUDA and Triton",
)

# Use this decorator only when hitting Triton bugs on H100
running_on_a100_only = skipUnless(
    torch.cuda.is_available()
@@ -4,6 +4,7 @@ import ast
import io
import itertools
import math
import random
import sys
import tempfile
import unittest
@@ -25,6 +26,7 @@ from torch.nested._internal.nested_tensor import (
    NestedTensor,
    ViewNestedFromBuffer,
)
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    SM70OrLater,
@@ -33,6 +35,7 @@ from torch.testing._internal.common_cuda import (
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    flex_attention_supported_platform,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
@@ -7023,6 +7026,121 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
        self.assertTrue(torch.allclose(attn_output_eager, attn_output))
        self.assertTrue(torch.allclose(value_grad, value.grad))

    # Helper function to generate random query, key, value NJTs in (B, n_heads, *, D) format.
    # If noncontig_with_holes is True, the results will be non-contiguous with holes (i.e. have
    # both offsets and lengths specified).
    def _rand_qkv(self, device, dtype, noncontig_with_holes=False):
        batch_size = 8
        n_heads = 8
        D = 16

        sentence_lengths = [random.randint(2, 1023) for _ in range(batch_size - 1)]
        total = sum(sentence_lengths)

        # shape (B, *, D_total) where D_total = n_heads * D
        query = torch.nested.nested_tensor(
            [
                torch.randn(l, n_heads * D, device=device, dtype=dtype)
                for l in sentence_lengths
            ],
            layout=torch.jagged,
        )
        if noncontig_with_holes:
            query = torch.nested.nested_tensor_from_jagged(
                query._values,
                query._offsets,
                # -1 to introduce holes
                lengths=query._offsets.diff() - 1,
                jagged_dim=query._ragged_idx,
                min_seqlen=query._min_seqlen,
                max_seqlen=query._max_seqlen,
            )
        # NB: randn_like() doesn't propagate lengths so this doesn't preserve non-contiguity
        key = torch.randn_like(query)
        value = torch.randn_like(query)

        # shape (B, *, D_total) -> (B, n_heads, *, D)
        query = (
            query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
        )
        key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
        value = (
            value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
        )

        return query, key, value

    @onlyCUDA
    @flex_attention_supported_platform
    @dtypes(torch.float32)
    # non-contiguous with holes not supported yet
    @decorateIf(unittest.skip, lambda params: params["noncontig_with_holes"])
    @parametrize("noncontig_with_holes", [False, True])
    def test_flex_attention(self, device, dtype, noncontig_with_holes):
        query, key, value = self._rand_qkv(device, dtype, noncontig_with_holes)

        # Run FlexAttention with a causal mask
        def causal_mask(b, h, q_idx, kv_idx):
            return q_idx >= kv_idx

        block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
        out_flex = flex_attention(query, key, value, block_mask=block_mask)
        grad_out = torch.randn_like(out_flex)
        grads_flex = torch.autograd.grad(
            out_flex, inputs=(query, key, value), grad_outputs=(grad_out,)
        )
        flex_outs = [out_flex, *grads_flex]

        # Run FlexAttention with a score_mod that represents causal attention
        def causal_score_mod(score, b, h, q_idx, kv_idx):
            return torch.where(q_idx >= kv_idx, score, float("-inf"))

        out_flex2 = flex_attention(query, key, value, score_mod=causal_score_mod)
        grads_flex2 = torch.autograd.grad(
            out_flex2, inputs=(query, key, value), grad_outputs=(grad_out,)
        )
        flex_outs2 = [out_flex2, *grads_flex2]

        # Run causal SDPA for comparison
        out_sdpa = F.scaled_dot_product_attention(query, key, value, is_causal=True)
        grads_sdpa = torch.autograd.grad(
            out_sdpa, inputs=(query, key, value), grad_outputs=(grad_out,)
        )
        sdpa_outs = [out_sdpa, *grads_sdpa]

        # Compare flex vs. SDPA output and grads
        for flex, flex2, sdpa in zip(flex_outs, flex_outs2, sdpa_outs):
            self.assertTrue(flex.is_nested and flex2.is_nested and sdpa.is_nested)
            self.assertEqual(flex, sdpa, atol=1e-2, rtol=1e-2)
            self.assertEqual(flex2, sdpa, atol=1e-2, rtol=1e-2)

    @onlyCUDA
    @flex_attention_supported_platform
    @dtypes(torch.float32)
    def test_flex_attention_converts_stacked_seq_indices(self, device, dtype):
        # This test verifies that a score_mod function written to operate within
        # NJT sequence index space, such as a lookup table, works correctly. This
        # validates that FlexAttention properly converts indices within the
        # "stacked sequence" space used for NJT -> sequence-relative indices.
        query, key, value = self._rand_qkv(device, dtype)

        # Test with score_mod
        score_mod_table = torch.randn(query._max_seqlen, device=device, dtype=dtype)

        def my_score_mod(score, b, h, q_idx, kv_idx):
            return score_mod_table[q_idx]

        flex_attention(query, key, value, score_mod=my_score_mod)

        # Test with mask_mod
        mask_mod_table = score_mod_table > 0.0

        def my_mask_mod(b, h, q_idx, kv_idx):
            return mask_mod_table[q_idx]

        block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, _compile=True)
        flex_attention(query, key, value, block_mask=block_mask)

    @dtypes(torch.float32)
    def test_apply_(self, device, dtype):
        nt = random_nt_from_dims(
@@ -2195,7 +2195,8 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
        out_meta = torch.empty_like(
            query_meta, memory_format=torch.contiguous_format
        )
        lse_meta = query_meta.new_empty(logsumexp_shape, dtype=torch.float32)
        # TODO: Figure out a better way to handle this for NJT than using sum()
        lse_meta = torch.empty_like(query_meta, dtype=torch.float32).sum(dim=-1)
        example_value = (out_meta, lse_meta)

        # Compose the ordered HOO args:
@@ -1962,6 +1962,150 @@ def record_stream_default(func, *args, **kwargs):
        func(inp._lengths, stream)


@register_jagged_func(
    torch.ops.aten.new_empty.default,
    "self: jt_all, size: any, dtype: any?, layout: any?, device: any?, pin_memory: any?",
)
def new_empty_default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(  # type: ignore[misc]
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")

    if len(new_kwargs["size"]) == 0:
        return func(inp._values, **new_kwargs)

    raise RuntimeError("new_empty() not supported for NJT with shape != ()")


from torch._higher_order_ops.flex_attention import (
    flex_attention as flex_attention_hop,
    flex_attention_backward as flex_attention_backward_hop,
)
from torch.fx.graph_module import GraphModule


@flex_attention_hop.py_impl(NestedTensor)  # type: ignore[misc]
def flex_njt(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    score_mod: Callable,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
    assert query.dim() == 4 and key.dim() == 4 and value.dim() == 4

    # TODO: Support this if needed; determine if NJT buffers need be unwrapped as dense.
    if any(
        isinstance(buf, torch.Tensor) and buf.is_nested
        for buf in score_mod_other_buffers + mask_mod_other_buffers
    ):
        raise RuntimeError(
            "flex_attention(): Nested tensor score_mod / mask_mod buffers are not "
            "currently supported. Please file an issue if this is important to you."
        )

    # need to pass dense tensor of shape (B, n_heads, sum(seq_len), D)
    output = flex_attention_hop(
        query.values().unsqueeze(0),
        key.values().unsqueeze(0),
        value.values().unsqueeze(0),
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        kernel_options=kernel_options,
        score_mod_other_buffers=score_mod_other_buffers,
        mask_mod_other_buffers=mask_mod_other_buffers,
    )

    # wrap outputs as NJT
    output_njt = torch.nested.nested_tensor_from_jagged(
        output[0].transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    logsumexp_njt = torch.nested.nested_tensor_from_jagged(
        output[1].transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    return (output_njt, logsumexp_njt)


@flex_attention_backward_hop.py_impl(NestedTensor)  # type: ignore[misc]
def flex_njt_backward(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    logsumexp: torch.Tensor,
    grad_out: torch.Tensor,
    grad_logsumexp: torch.Tensor,
    fw_graph: Union[Callable, GraphModule],
    joint_graph: GraphModule,
    block_mask: Tuple,
    scale: float,
    kernel_options: Dict[str, Any],
    score_mod_other_buffers: Tuple = (),
    mask_mod_other_buffers: Tuple = (),
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
    output = flex_attention_backward_hop(
        query.values().unsqueeze(0),
        key.values().unsqueeze(0),
        value.values().unsqueeze(0),
        out=out.values().unsqueeze(0),
        logsumexp=logsumexp.values().unsqueeze(0),
        grad_out=grad_out.values().unsqueeze(0),
        grad_logsumexp=grad_logsumexp.values().unsqueeze(0),
        fw_graph=fw_graph,
        joint_graph=joint_graph,
        block_mask=block_mask,
        scale=scale,
        kernel_options=kernel_options,
        score_mod_other_buffers=score_mod_other_buffers,
        mask_mod_other_buffers=mask_mod_other_buffers,
    )

    # wrap grads as NJTs
    dense_q_grad, dense_k_grad, dense_v_grad, score_mod_other_buffer_grads = output
    njt_q_grad = torch.nested.nested_tensor_from_jagged(
        dense_q_grad.transpose(1, 2).squeeze(0),
        query._offsets,  # type: ignore[attr-defined]
        query._lengths,  # type: ignore[attr-defined]
        min_seqlen=query._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=query._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)
    njt_k_grad = torch.nested.nested_tensor_from_jagged(
        dense_k_grad.transpose(1, 2).squeeze(0),
        key._offsets,  # type: ignore[attr-defined]
        key._lengths,  # type: ignore[attr-defined]
        min_seqlen=key._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=key._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)
    njt_v_grad = torch.nested.nested_tensor_from_jagged(
        dense_v_grad.transpose(1, 2).squeeze(0),
        value._offsets,  # type: ignore[attr-defined]
        value._lengths,  # type: ignore[attr-defined]
        min_seqlen=value._maybe_min_seqlen,  # type: ignore[attr-defined]
        max_seqlen=value._maybe_max_seqlen,  # type: ignore[attr-defined]
    ).transpose(1, 2)

    return (njt_q_grad, njt_k_grad, njt_v_grad, score_mod_other_buffer_grads)


# Make the dummy available on the C++ side.
@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
def _nested_get_jagged_dummy(func, *args, **kwargs):
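For illustration, here is a small usage sketch (not part of the diff) of the barebones `new_empty()` support registered above; with this PR applied, only an empty shape is expected to work, and the NJT constructed below is arbitrary example data.

```python
import torch

# A small jagged-layout NJT with sequences of lengths 2 and 3.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)],
    layout=torch.jagged,
)

# Shape () dispatches to new_empty_default above and should return a 0-dim
# dense tensor with the NJT's dtype/device, which FlexAttention relies on.
scalar = nt.new_empty(())
print(scalar.shape)  # torch.Size([])

# Any non-empty shape is expected to raise, per the registration above.
try:
    nt.new_empty((2, 3))
except RuntimeError as e:
    print(e)  # new_empty() not supported for NJT with shape != ()
```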
@@ -29,6 +29,7 @@ __all__ = [
    "flex_attention",
    "create_block_mask",
    "create_mask",
    "create_nested_block_mask",
    "or_masks",
    "and_masks",
    "noop_mask",
@@ -895,6 +896,130 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
    )


def _nested_mod_func_adapter(
    orig_mod_func: Union[_score_mod_signature, _mask_mod_signature],
    nt: torch.Tensor,
    is_score_mod: bool,
) -> Union[_score_mod_signature, _mask_mod_signature]:
    r"""Adapter to convert a score_mod / mask_mod to be NJT-compatible. The given mod func
    should be written as if operating over a single sequence at a time. This adapter will
    handle conversion from indices operating over a "stacked sequence" of length ``sum(S)``
    for sequence length ``S`` in the NJT to "sequence relative" indices in range ``[0, S)``.

    Args:
        orig_mod_func (Callable): Function to modify attention scores. It takes four or five
            arguments, depending on whether a mask_mod or score_mod func is passed.
        nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
            structure for query / key / value.
        is_score_mod (bool): Indicates whether the mod function is a score_mod.

    Returns:
        nt_score_mod: An NJT-compatible version of orig_score_mod
    """

    # Used to convert indices within the "stacked" sequence (range [0, sum(*)))
    # to "sequence local" indices (range [0, S) for each S).
    def _build_seq_idx(offsets, total_length):
        range_tensor = torch.arange(
            total_length, device=offsets.device, dtype=torch.int32
        )

        # Use searchsorted to find the index for each position
        # NB: This assumes offsets[0] to offsets[-1] spans the packed dim of values.
        # If we ever loosen this restriction, this logic will need to be updated.
        seq_idx = torch.searchsorted(offsets, range_tensor, right=True) - 1
        return seq_idx

    offsets = nt._offsets  # type: ignore[attr-defined]
    total_length = nt._values.shape[nt._ragged_idx - 1]  # type: ignore[attr-defined]
    seq_idx = _build_seq_idx(offsets, total_length)

    # Converts q_idx / kv_idx from [0, total_length) -> [0, S), where S refers
    # to the sequence length for each sequence in the NJT, for use in the given
    # score_mod. This allows the user to write a score_mod as if it were
    # operating on a single sequence and the "stacked sequence" is split
    # automatically into individual sequences for them.
    if is_score_mod:

        def nt_score_mod(score, b, h, q_idx, kv_idx):
            q_nested = q_idx - offsets[seq_idx[q_idx]]
            kv_nested = kv_idx - offsets[seq_idx[kv_idx]]
            is_same_sequence = seq_idx[q_idx] == seq_idx[kv_idx]
            return torch.where(
                is_same_sequence,
                orig_mod_func(score, b, h, q_nested, kv_nested),  # type: ignore[call-arg]
                # don't allow inter-sequence attention
                float("-inf"),
            )

        return nt_score_mod
    else:

        def nt_mask_mod(b, h, q_idx, kv_idx):
            q_nested = q_idx - offsets[seq_idx[q_idx]]
            kv_nested = kv_idx - offsets[seq_idx[kv_idx]]
            # don't allow inter-sequence attention
            is_same_sequence = seq_idx[q_idx] == seq_idx[kv_idx]
            return orig_mod_func(b, h, q_nested, kv_nested) & is_same_sequence  # type: ignore[call-arg]

        return nt_mask_mod


def create_nested_block_mask(
    mask_mod: _mask_mod_signature,
    B: Optional[int],
    H: Optional[int],
    nt: torch.Tensor,
    BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
    _compile=False,
) -> BlockMask:
    r"""This function creates a nested tensor compatible block mask tuple from a mask_mod
    function. The returned BlockMask will be on the device specified by the input nested tensor.

    Args:
        mask_mod (Callable): mask_mod function. This is a callable that defines the
            masking pattern for the attention mechanism. It takes four arguments:
            b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index).
            It should return a boolean tensor indicating which attention connections are allowed
            (True) or masked out (False).
        B (int): Batch size.
        H (int): Number of query heads.
        nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
            structure for query / key / value. The block mask will be constructed to operate on
            a "stacked sequence" of length ``sum(S)`` for sequence length ``S`` from the NJT.
        BLOCK_SIZE (int or Tuple[int, int]): Block size for the block mask. If a single int is
            provided it is used for both query and key/value.

    Returns:
        BlockMask: A BlockMask object that contains the block mask information.

    Example Usage:
        .. code-block:: python

            query = ...  # jagged layout NJT of shape (B, H, S*, D)
            key = ...  # jagged layout NJT of shape (B, H, S*, D)
            value = ...  # jagged layout NJT of shape (B, H, S*, D)

            def causal_mask(b, h, q_idx, kv_idx):
                return q_idx >= kv_idx

            block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
            output = flex_attention(query, key, value, block_mask=block_mask)
    """
    return create_block_mask(
        _nested_mod_func_adapter(mask_mod, nt, is_score_mod=False),  # type: ignore[arg-type]
        B,
        H,
        nt._values.shape[nt._ragged_idx - 1],  # type: ignore[attr-defined]
        nt._values.shape[nt._ragged_idx - 1],  # type: ignore[attr-defined]
        device=nt.device,  # type: ignore[arg-type]
        # compile is important so we don't materialize a mask_tensor of
        # shape (1, 1, total_seqlen, total_seqlen)
        BLOCK_SIZE=BLOCK_SIZE,
        _compile=_compile,
    )


def _apply_kernel_options(
    query: Tensor, key: Tensor, value: Tensor, return_lse: bool, kernel_options
):
@@ -944,6 +1069,25 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor):
        )


def _validate_nestedness(query: Tensor, key: Tensor, value: Tensor):
    # Currently, inputs can only be all nested or all non-nested.
    if query.is_nested != key.is_nested or key.is_nested != value.is_nested:
        raise ValueError(
            "FlexAttention does not support mixed nested tensor / non-nested tensor inputs. "
            "Please file an issue requesting this if it is important to you."
        )

    if (
        (query.is_nested and query._lengths is not None)  # type: ignore[attr-defined]
        or (key.is_nested and key._lengths is not None)  # type: ignore[attr-defined]
        or (value.is_nested and value._lengths is not None)  # type: ignore[attr-defined]
    ):
        raise ValueError(
            "FlexAttention does not support nested tensors that are non-contiguous with holes. "
            "Please file an issue requesting this if it is important to you."
        )


def flex_attention(
    query: Tensor,
    key: Tensor,
@@ -1011,6 +1155,7 @@ def flex_attention(
    _validate_sdpa_input(query, key, value)
    _validate_embed_dim(query, key, value)
    _validate_device(query, key, value)
    _validate_nestedness(query, key, value)
    if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
        raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
    if (not enable_gqa) and query.size(-3) != key.size(-3):
@@ -1030,15 +1175,29 @@

    if score_mod is None:
        score_mod = _identity
    elif query.is_nested:
        score_mod = _nested_mod_func_adapter(score_mod, query, is_score_mod=True)  # type: ignore[assignment]

    if block_mask is None:
        block_mask = _create_empty_block_mask(query, key)
    elif (
    elif not query.is_nested and (
        query.size(-2) < block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
        or key.size(-2) < block_mask.kv_indices.size(-1) * block_mask.BLOCK_SIZE[1]
    ):
        new_q_len = _round_up_to_multiple(query.size(-2), block_mask.BLOCK_SIZE[0])
        new_kv_len = _round_up_to_multiple(key.size(-2), block_mask.BLOCK_SIZE[1])
        block_mask = block_mask._adjust(new_q_len, new_kv_len)
    elif query.is_nested and (
        block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
        != _round_up_to_multiple(
            query._values.size(query._ragged_idx - 1), block_mask.BLOCK_SIZE[0]  # type: ignore[attr-defined]
        )
    ):
        # TODO: Maybe we want to auto-adjust for this case as well?
        raise RuntimeError(
            f"block_mask of shape {block_mask.shape} is not compatible with nested tensor input "
            f"with total sequence length of {query._values.size(query._ragged_idx - 1)}"  # type: ignore[attr-defined]
        )
    if scale is None:
        scale = 1.0 / math.sqrt(query.size(-1))
@@ -1950,3 +1950,12 @@ def skipPRIVATEUSE1(fn):
# This should probably enumerate all available device type test base classes.
def get_all_device_types() -> List[str]:
    return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]


flex_attention_supported_platform = unittest.skipUnless(
    torch.cuda.is_available()
    and torch.version.hip is None
    and torch.utils._triton.has_triton()
    and torch.cuda.get_device_capability() >= (8, 0),
    "Requires CUDA and Triton",
)