FlexAttention support for NJT (#136792)

This PR adds FlexAttention + NJT support. In particular:
* To handle raggedness, the packed sequence dim of input NJTs is treated as one giant "stacked sequence". So that user `score_mod` / `mask_mod` functions can still be written in the original NJT sequence space, this PR automatically converts indices within the giant "stacked sequence" to sequence-relative indices (a short sketch of this conversion follows the list below).
* Provides `NestedTensor` `py_impl`s for the flex attention forward / backward HOPs that simply unwrap the NJT inputs and wrap the dense outputs back into NJTs
* Adds barebones `new_empty()` support to NJT since FlexAttention utilizes this repeatedly; right now, only `new_empty()` with a shape of `()` is supported
* Tests that FlexAttention with a causal mask matches causal SDPA
* Adds a new public API for FlexAttention usage:
    * `create_nested_block_mask(mask_mod, B, H, njt, BLOCK_SIZE, _compile)` - NJT analogue for `create_block_mask()` that utilizes the `njt`'s ragged structure to create an appropriately-sized block mask (e.g. `(1, 1, total_seqlen, total_seqlen)`). This function handles the index conversion from "stacked sequence" space -> relative sequence space.
      * Minor note: as this is a public API, this function is purposefully named with "nested" instead of "njt" to keep the latter as an informal, mostly internal-only term.
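
For intuition, the "stacked sequence" -> sequence-relative conversion boils down to a `searchsorted` over the NJT's offsets. Below is a minimal standalone sketch using toy tensors (the in-tree logic lives in the `_nested_mod_func_adapter` helper added by this PR):

```python
import torch

def build_seq_idx(offsets: torch.Tensor, total_length: int) -> torch.Tensor:
    # Map each position in the packed "stacked sequence" to the sequence it belongs to.
    positions = torch.arange(total_length, device=offsets.device, dtype=torch.int32)
    return torch.searchsorted(offsets, positions, right=True) - 1

# Toy example: two sequences of lengths 3 and 2 packed into a length-5 stacked sequence.
offsets = torch.tensor([0, 3, 5])
seq_idx = build_seq_idx(offsets, total_length=5)  # tensor([0, 0, 0, 1, 1])

# Convert a stacked-sequence index to a sequence-relative one.
q_idx = torch.tensor(4)
q_rel = q_idx - offsets[seq_idx[q_idx]]  # position 4 belongs to sequence 1, at relative index 1
```

A mod function written against these relative indices behaves as if each sequence were processed independently; cross-sequence (q, kv) pairs are masked out separately by the adapter.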

Example usage:
```python
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

query = ... # NJT of shape (B, H, S*, D)
key = ... # NJT of shape (B, H, S*, D)
value = ... # NJT of shape (B, H, S*, D)
# create_nested_block_mask() automatically converts indices from "stacked sequence" space -> relative sequence space
block_mask = create_nested_block_mask(causal_mask, 1, 1, query)  # block mask conceptual shape is (B, H, sum(S*), sum(S*))
output = flex_attention(query, key, value, block_mask=block_mask)

def causal_score_mod(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# flex_attention() automatically converts indices from "stacked sequence" space -> relative sequence space for NJT inputs
output2 = flex_attention(query, key, value, score_mod=causal_score_mod)
```

TODO:
* ~~Determine the right level of abstraction for public API helpers + move them alongside other helpers~~ Verify this with others though
* ~~Some cleanup~~
* ~~`njt_score_mod_adapter`~~
* ~~Q: should `create_njt_block_mask()` call `njt_mask_mod_adapter()` so we don't need two calls?~~
* Can we avoid materializing the `sum(s)` length `seq_idx` used for conversion between stacked sequence -> sequence relative indices?
    * Not for now, although future work may deepen the integration between Flex + NJT (possibly requiring custom templates). We should try to cache this, though (a hypothetical caching sketch follows this list).
* ~~Demonstrate non-causal mask~~
* Support non-contiguous NJTs with holes (**booted to future PR**)
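
Regarding the caching TODO above, one purely hypothetical direction (not part of this PR) is to memoize the materialized `seq_idx` keyed on the offsets tensor, e.g.:

```python
import torch
from typing import Dict, Tuple

# Hypothetical cache; keyed on the offsets tensor's identity and length, which is only
# safe while that tensor stays alive and unmodified.
_seq_idx_cache: Dict[Tuple[int, int], torch.Tensor] = {}

def cached_seq_idx(offsets: torch.Tensor, total_length: int) -> torch.Tensor:
    key = (id(offsets), total_length)
    if key not in _seq_idx_cache:
        positions = torch.arange(total_length, device=offsets.device, dtype=torch.int32)
        _seq_idx_cache[key] = torch.searchsorted(offsets, positions, right=True) - 1
    return _seq_idx_cache[key]
```

A real implementation would more likely hang the cached tensor off the NJT's metadata (similar to how min / max seqlen are cached) rather than use a module-level dict.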

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136792
Approved by: https://github.com/drisspg
ghstack dependencies: #138841
Authored by: Joel Schlosser (2024-10-25 12:50:23 -04:00)
Committed by: PyTorch MergeBot
Parent: 4cd985a886
Commit: 8ba9063002
7 changed files with 437 additions and 10 deletions


@@ -14,6 +14,7 @@ BlockMask Utilities
.. autofunction:: create_block_mask
.. autofunction:: create_mask
.. autofunction:: create_nested_block_mask
.. autofunction:: and_masks
.. autofunction:: or_masks
.. autofunction:: noop_mask


@@ -32,18 +32,13 @@ from torch.nn.attention.flex_attention import (
from torch.testing import FileCheck
from torch.testing._internal import common_utils
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_BF16, TEST_MULTIGPU
from torch.testing._internal.common_device_type import (
flex_attention_supported_platform as supported_platform,
)
from torch.testing._internal.common_utils import TEST_WITH_ROCM
from torch.utils._triton import has_triton
# Skip tests if Triton is not available
supported_platform = skipUnless(
torch.cuda.is_available()
and has_triton()
and torch.cuda.get_device_capability() >= (8, 0),
"Requires CUDA and Triton",
)
# Use this decorator only when hitting Triton bugs on H100
running_on_a100_only = skipUnless(
torch.cuda.is_available()


@@ -4,6 +4,7 @@ import ast
import io
import itertools
import math
import random
import sys
import tempfile
import unittest
@@ -25,6 +26,7 @@ from torch.nested._internal.nested_tensor import (
NestedTensor,
ViewNestedFromBuffer,
)
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FUSED_ATTENTION,
SM70OrLater,
@@ -33,6 +35,7 @@ from torch.testing._internal.common_cuda import (
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
flex_attention_supported_platform,
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
@@ -7023,6 +7026,121 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
self.assertTrue(torch.allclose(attn_output_eager, attn_output))
self.assertTrue(torch.allclose(value_grad, value.grad))
# Helper function to generate random query, key, value NJTs in (B, n_heads, *, D) format.
# If noncontig_with_holes is True, the results will be non-contiguous with holes (i.e. have
# both offsets and lengths specified).
def _rand_qkv(self, device, dtype, noncontig_with_holes=False):
batch_size = 8
n_heads = 8
D = 16
sentence_lengths = [random.randint(2, 1023) for _ in range(batch_size - 1)]
total = sum(sentence_lengths)
# shape (B, *, D_total) where D_total = n_heads * D
query = torch.nested.nested_tensor(
[
torch.randn(l, n_heads * D, device=device, dtype=dtype)
for l in sentence_lengths
],
layout=torch.jagged,
)
if noncontig_with_holes:
query = torch.nested.nested_tensor_from_jagged(
query._values,
query._offsets,
# -1 to introduce holes
lengths=query._offsets.diff() - 1,
jagged_dim=query._ragged_idx,
min_seqlen=query._min_seqlen,
max_seqlen=query._max_seqlen,
)
# NB: randn_like() doesn't propagate lengths so this doesn't preserve non-contiguity
key = torch.randn_like(query)
value = torch.randn_like(query)
# shape (B, *, D_total) -> (B, n_heads, *, D)
query = (
query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = (
value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
return query, key, value
@onlyCUDA
@flex_attention_supported_platform
@dtypes(torch.float32)
# non-contiguous with holes not supported yet
@decorateIf(unittest.skip, lambda params: params["noncontig_with_holes"])
@parametrize("noncontig_with_holes", [False, True])
def test_flex_attention(self, device, dtype, noncontig_with_holes):
query, key, value = self._rand_qkv(device, dtype, noncontig_with_holes)
# Run FlexAttention with a causal mask
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
out_flex = flex_attention(query, key, value, block_mask=block_mask)
grad_out = torch.randn_like(out_flex)
grads_flex = torch.autograd.grad(
out_flex, inputs=(query, key, value), grad_outputs=(grad_out,)
)
flex_outs = [out_flex, *grads_flex]
# Run FlexAttention with a score_mod that represents causal attention
def causal_score_mod(score, b, h, q_idx, kv_idx):
return torch.where(q_idx >= kv_idx, score, float("-inf"))
out_flex2 = flex_attention(query, key, value, score_mod=causal_score_mod)
grads_flex2 = torch.autograd.grad(
out_flex2, inputs=(query, key, value), grad_outputs=(grad_out,)
)
flex_outs2 = [out_flex2, *grads_flex2]
# Run causal SDPA for comparison
out_sdpa = F.scaled_dot_product_attention(query, key, value, is_causal=True)
grads_sdpa = torch.autograd.grad(
out_sdpa, inputs=(query, key, value), grad_outputs=(grad_out,)
)
sdpa_outs = [out_sdpa, *grads_sdpa]
# Compare flex vs. SDPA output and grads
for flex, flex2, sdpa in zip(flex_outs, flex_outs2, sdpa_outs):
self.assertTrue(flex.is_nested and flex2.is_nested and sdpa.is_nested)
self.assertEqual(flex, sdpa, atol=1e-2, rtol=1e-2)
self.assertEqual(flex2, sdpa, atol=1e-2, rtol=1e-2)
@onlyCUDA
@flex_attention_supported_platform
@dtypes(torch.float32)
def test_flex_attention_converts_stacked_seq_indices(self, device, dtype):
# This test verifies that a score_mod function written to operate within
# NJT sequence index space, such as a lookup table, works correctly. This
# validates that FlexAttention properly converts indices within the
# "stacked sequence" space used for NJT -> sequence-relative indices.
query, key, value = self._rand_qkv(device, dtype)
# Test with score_mod
score_mod_table = torch.randn(query._max_seqlen, device=device, dtype=dtype)
def my_score_mod(score, b, h, q_idx, kv_idx):
return score_mod_table[q_idx]
flex_attention(query, key, value, score_mod=my_score_mod)
# Test with mask_mod
mask_mod_table = score_mod_table > 0.0
def my_mask_mod(b, h, q_idx, kv_idx):
return mask_mod_table[q_idx]
block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, _compile=True)
flex_attention(query, key, value, block_mask=block_mask)
@dtypes(torch.float32)
def test_apply_(self, device, dtype):
nt = random_nt_from_dims(


@@ -2195,7 +2195,8 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
out_meta = torch.empty_like(
query_meta, memory_format=torch.contiguous_format
)
lse_meta = query_meta.new_empty(logsumexp_shape, dtype=torch.float32)
# TODO: Figure out a better way to handle this for NJT than using sum()
lse_meta = torch.empty_like(query_meta, dtype=torch.float32).sum(dim=-1)
example_value = (out_meta, lse_meta)
# Compose the ordered HOO args:


@@ -1962,6 +1962,150 @@ def record_stream_default(func, *args, **kwargs):
func(inp._lengths, stream)
@register_jagged_func(
torch.ops.aten.new_empty.default,
"self: jt_all, size: any, dtype: any?, layout: any?, device: any?, pin_memory: any?",
)
def new_empty_default(func, *args, **kwargs):
_, new_kwargs = normalize_function( # type: ignore[misc]
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
if len(new_kwargs["size"]) == 0:
return func(inp._values, **new_kwargs)
raise RuntimeError("new_empty() not supported for NJT with shape != ()")
from torch._higher_order_ops.flex_attention import (
flex_attention as flex_attention_hop,
flex_attention_backward as flex_attention_backward_hop,
)
from torch.fx.graph_module import GraphModule
@flex_attention_hop.py_impl(NestedTensor) # type: ignore[misc]
def flex_njt(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
score_mod: Callable,
block_mask: Tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[torch.Tensor, torch.Tensor]:
assert query.dim() == 4 and key.dim() == 4 and value.dim() == 4
# TODO: Support this if needed; determine if NJT buffers need to be unwrapped as dense.
if any(
isinstance(buf, torch.Tensor) and buf.is_nested
for buf in score_mod_other_buffers + mask_mod_other_buffers
):
raise RuntimeError(
"flex_attention(): Nested tensor score_mod / mask_mod buffers are not "
"currently supported. Please file an issue if this is important to you."
)
# need to pass dense tensor of shape (B, n_heads, sum(seq_len), D)
output = flex_attention_hop(
query.values().unsqueeze(0),
key.values().unsqueeze(0),
value.values().unsqueeze(0),
score_mod=score_mod,
block_mask=block_mask,
scale=scale,
kernel_options=kernel_options,
score_mod_other_buffers=score_mod_other_buffers,
mask_mod_other_buffers=mask_mod_other_buffers,
)
# wrap outputs as NJT
output_njt = torch.nested.nested_tensor_from_jagged(
output[0].transpose(1, 2).squeeze(0),
query._offsets, # type: ignore[attr-defined]
query._lengths, # type: ignore[attr-defined]
min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined]
max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined]
).transpose(1, 2)
logsumexp_njt = torch.nested.nested_tensor_from_jagged(
output[1].transpose(1, 2).squeeze(0),
query._offsets, # type: ignore[attr-defined]
query._lengths, # type: ignore[attr-defined]
min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined]
max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined]
).transpose(1, 2)
return (output_njt, logsumexp_njt)
@flex_attention_backward_hop.py_impl(NestedTensor) # type: ignore[misc]
def flex_njt_backward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
logsumexp: torch.Tensor,
grad_out: torch.Tensor,
grad_logsumexp: torch.Tensor,
fw_graph: Union[Callable, GraphModule],
joint_graph: GraphModule,
block_mask: Tuple,
scale: float,
kernel_options: Dict[str, Any],
score_mod_other_buffers: Tuple = (),
mask_mod_other_buffers: Tuple = (),
) -> Tuple[
torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...]
]:
output = flex_attention_backward_hop(
query.values().unsqueeze(0),
key.values().unsqueeze(0),
value.values().unsqueeze(0),
out=out.values().unsqueeze(0),
logsumexp=logsumexp.values().unsqueeze(0),
grad_out=grad_out.values().unsqueeze(0),
grad_logsumexp=grad_logsumexp.values().unsqueeze(0),
fw_graph=fw_graph,
joint_graph=joint_graph,
block_mask=block_mask,
scale=scale,
kernel_options=kernel_options,
score_mod_other_buffers=score_mod_other_buffers,
mask_mod_other_buffers=mask_mod_other_buffers,
)
# wrap grads as NJTs
dense_q_grad, dense_k_grad, dense_v_grad, score_mod_other_buffer_grads = output
njt_q_grad = torch.nested.nested_tensor_from_jagged(
dense_q_grad.transpose(1, 2).squeeze(0),
query._offsets, # type: ignore[attr-defined]
query._lengths, # type: ignore[attr-defined]
min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined]
max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined]
).transpose(1, 2)
njt_k_grad = torch.nested.nested_tensor_from_jagged(
dense_k_grad.transpose(1, 2).squeeze(0),
key._offsets, # type: ignore[attr-defined]
key._lengths, # type: ignore[attr-defined]
min_seqlen=key._maybe_min_seqlen, # type: ignore[attr-defined]
max_seqlen=key._maybe_max_seqlen, # type: ignore[attr-defined]
).transpose(1, 2)
njt_v_grad = torch.nested.nested_tensor_from_jagged(
dense_v_grad.transpose(1, 2).squeeze(0),
value._offsets, # type: ignore[attr-defined]
value._lengths, # type: ignore[attr-defined]
min_seqlen=value._maybe_min_seqlen, # type: ignore[attr-defined]
max_seqlen=value._maybe_max_seqlen, # type: ignore[attr-defined]
).transpose(1, 2)
return (njt_q_grad, njt_k_grad, njt_v_grad, score_mod_other_buffer_grads)
# Make the dummy available on the C++ side.
@register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any")
def _nested_get_jagged_dummy(func, *args, **kwargs):


@@ -29,6 +29,7 @@ __all__ = [
"flex_attention",
"create_block_mask",
"create_mask",
"create_nested_block_mask",
"or_masks",
"and_masks",
"noop_mask",
@@ -895,6 +896,130 @@ def _create_empty_block_mask(query: Tensor, key: Tensor) -> BlockMask:
)
def _nested_mod_func_adapter(
orig_mod_func: Union[_score_mod_signature, _mask_mod_signature],
nt: torch.Tensor,
is_score_mod: bool,
) -> Union[_score_mod_signature, _mask_mod_signature]:
r"""Adapter to convert a score_mod / mask_mod to be NJT-compatible. The given mod func
should be written as if operating over a single sequence at a time. This adapter will
handle conversion from indices operating over a "stacked sequence" of length ``sum(S)``
for sequence length ``S`` in the NJT to "sequence relative" indices in range ``[0, S)``.
Args:
orig_mod_func (Callable): Function to modify attention scores. It takes four or five
arguments, depending on whether a mask_mod or score_mod func is passed.
nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
structure for query / key / value.
is_score_mod (bool): Indicates whether the mod function is a score_mod.
Returns:
nt_mod_func: An NJT-compatible version of ``orig_mod_func``
"""
# Used to convert indices within the "stacked" sequence (range [0, sum(*)))
# to "sequence local" indices (range [0, S) for each S).
def _build_seq_idx(offsets, total_length):
range_tensor = torch.arange(
total_length, device=offsets.device, dtype=torch.int32
)
# Use searchsorted to find the index for each position
# NB: This assumes offsets[0] to offsets[-1] spans the packed dim of values.
# If we ever loosen this restriction, this logic will need to be updated.
seq_idx = torch.searchsorted(offsets, range_tensor, right=True) - 1
return seq_idx
offsets = nt._offsets # type: ignore[attr-defined]
total_length = nt._values.shape[nt._ragged_idx - 1] # type: ignore[attr-defined]
seq_idx = _build_seq_idx(offsets, total_length)
# Converts q_idx / kv_idx from [0, total_length) -> [0, S), where S refers
# to the sequence length for each sequence in the NJT, for use in given
# score_mod. This allows the user to write a score_mod as if it were
# operating on a single sequence and the "stacked sequence" is split
# automatically into individual sequences for them.
if is_score_mod:
def nt_score_mod(score, b, h, q_idx, kv_idx):
q_nested = q_idx - offsets[seq_idx[q_idx]]
kv_nested = kv_idx - offsets[seq_idx[kv_idx]]
is_same_sequence = seq_idx[q_idx] == seq_idx[kv_idx]
return torch.where(
is_same_sequence,
orig_mod_func(score, b, h, q_nested, kv_nested), # type: ignore[call-arg]
# don't allow inter-sequence attention
float("-inf"),
)
return nt_score_mod
else:
def nt_mask_mod(b, h, q_idx, kv_idx):
q_nested = q_idx - offsets[seq_idx[q_idx]]
kv_nested = kv_idx - offsets[seq_idx[kv_idx]]
# don't allow inter-sequence attention
is_same_sequence = seq_idx[q_idx] == seq_idx[kv_idx]
return orig_mod_func(b, h, q_nested, kv_nested) & is_same_sequence # type: ignore[call-arg]
return nt_mask_mod
def create_nested_block_mask(
mask_mod: _mask_mod_signature,
B: Optional[int],
H: Optional[int],
nt: torch.Tensor,
BLOCK_SIZE: Union[int, Tuple[int, int]] = _DEFAULT_SPARSE_BLOCK_SIZE,
_compile=False,
) -> BlockMask:
r"""This function creates a nested tensor compatible block mask from a mask_mod
function. The returned BlockMask will be on the device specified by the input nested tensor.
Args:
mask_mod (Callable): mask_mod function. This is a callable that defines the
masking pattern for the attention mechanism. It takes four arguments:
b (batch size), h (number of heads), q_idx (query index), and kv_idx (key/value index).
It should return a boolean tensor indicating which attention connections are allowed
(True) or masked out (False).
B (int): Batch size.
H (int): Number of query heads.
nt (torch.Tensor): Jagged layout nested tensor (NJT) that defines the sequence length
structure for query / key / value. The block mask will be constructed to operate on
a "stacked sequence" of length ``sum(S)`` for sequence length ``S`` from the NJT.
BLOCK_SIZE (int or Tuple[int, int]): Block size for the block mask. If a single int is
provided it is used for both query and key/value.
Returns:
BlockMask: A BlockMask object that contains the block mask information.
Example Usage:
.. code-block:: python
# shape (B, num_heads, seq_len*, D) where seq_len* varies across the batch
query = torch.nested.nested_tensor(..., layout=torch.jagged)
key = torch.nested.nested_tensor(..., layout=torch.jagged)
value = torch.nested.nested_tensor(..., layout=torch.jagged)
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
output = flex_attention(query, key, value, block_mask=block_mask)
"""
return create_block_mask(
_nested_mod_func_adapter(mask_mod, nt, is_score_mod=False), # type: ignore[arg-type]
B,
H,
nt._values.shape[nt._ragged_idx - 1], # type: ignore[attr-defined]
nt._values.shape[nt._ragged_idx - 1], # type: ignore[attr-defined]
device=nt.device, # type: ignore[arg-type]
# compile is important so we don't materialize a mask_tensor of
# shape (1, 1, total_seqlen, total_seqlen)
BLOCK_SIZE=BLOCK_SIZE,
_compile=_compile,
)
def _apply_kernel_options(
query: Tensor, key: Tensor, value: Tensor, return_lse: bool, kernel_options
):
@@ -944,6 +1069,25 @@ def _validate_device(query: Tensor, key: Tensor, value: Tensor):
)
def _validate_nestedness(query: Tensor, key: Tensor, value: Tensor):
# Currently, inputs must either all be nested or all be non-nested.
if query.is_nested != key.is_nested or key.is_nested != value.is_nested:
raise ValueError(
"FlexAttention does not support mixed nested tensor / non-nested tensor inputs. "
"Please file an issue requesting this if it is important to you."
)
if (
(query.is_nested and query._lengths is not None) # type: ignore[attr-defined]
or (key.is_nested and key._lengths is not None) # type: ignore[attr-defined]
or (value.is_nested and value._lengths is not None) # type: ignore[attr-defined]
):
raise ValueError(
"FlexAttention does not support nested tensors that are non-contiguous with holes. "
"Please file an issue requesting this if it is important to you."
)
def flex_attention(
query: Tensor,
key: Tensor,
@@ -1011,6 +1155,7 @@ def flex_attention(
_validate_sdpa_input(query, key, value)
_validate_embed_dim(query, key, value)
_validate_device(query, key, value)
_validate_nestedness(query, key, value)
if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
raise NotImplementedError("NYI: query, key, and value must be 4D tensors")
if (not enable_gqa) and query.size(-3) != key.size(-3):
@@ -1030,15 +1175,29 @@
if score_mod is None:
score_mod = _identity
elif query.is_nested:
score_mod = _nested_mod_func_adapter(score_mod, query, is_score_mod=True) # type: ignore[assignment]
if block_mask is None:
block_mask = _create_empty_block_mask(query, key)
elif (
elif not query.is_nested and (
query.size(-2) < block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
or key.size(-2) < block_mask.kv_indices.size(-1) * block_mask.BLOCK_SIZE[1]
):
new_q_len = _round_up_to_multiple(query.size(-2), block_mask.BLOCK_SIZE[0])
new_kv_len = _round_up_to_multiple(key.size(-2), block_mask.BLOCK_SIZE[1])
block_mask = block_mask._adjust(new_q_len, new_kv_len)
elif query.is_nested and (
block_mask.kv_num_blocks.size(-1) * block_mask.BLOCK_SIZE[0]
!= _round_up_to_multiple(
query._values.size(query._ragged_idx - 1), block_mask.BLOCK_SIZE[0] # type: ignore[attr-defined]
)
):
# TODO: Maybe we want to auto-adjust for this case as well?
raise RuntimeError(
f"block_mask of shape {block_mask.shape} is not compatible with nested tensor input "
f"with total sequence length of {query._values.size(query._ragged_idx - 1)}" # type: ignore[attr-defined]
)
if scale is None:
scale = 1.0 / math.sqrt(query.size(-1))


@@ -1950,3 +1950,12 @@ def skipPRIVATEUSE1(fn):
# This should probably enumerate all available device type test base classes.
def get_all_device_types() -> List[str]:
return ["cpu"] if not torch.cuda.is_available() else ["cpu", "cuda"]
flex_attention_supported_platform = unittest.skipUnless(
torch.cuda.is_available()
and torch.version.hip is None
and torch.utils._triton.has_triton()
and torch.cuda.get_device_capability() >= (8, 0),
"Requires CUDA and Triton",
)