Make FlexAttention API public (#130755)

# Summary

Makes the prototype `flex_attention` API public.
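The module moves from the private `torch.nn.attention._flex_attention` namespace to `torch.nn.attention.flex_attention`. A minimal sketch of the import change; the exported names are taken from the new docs page added in this PR:

```python
# Before this PR the prototype lived in a private module:
# from torch.nn.attention._flex_attention import flex_attention

# After this PR the public import path is:
from torch.nn.attention.flex_attention import (
    BlockMask,
    create_block_mask,
    create_mask,
    flex_attention,
)
```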

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130755
Approved by: https://github.com/Chillee
Author: drisspg
Date: 2024-07-16 16:21:23 +00:00
Committed by: PyTorch MergeBot
Parent: cbda8be537
Commit: 2b43d339fe

7 changed files with 96 additions and 51 deletions

View File

@@ -0,0 +1,23 @@
.. role:: hidden
    :class: hidden-section
======================================
torch.nn.attention.flex_attention
======================================
.. currentmodule:: torch.nn.attention.flex_attention
.. py:module:: torch.nn.attention.flex_attention
.. autofunction:: flex_attention
BlockMask Utilities
-------------------
.. autofunction:: create_block_mask
.. autofunction:: create_mask
BlockMask
---------
.. autoclass:: BlockMask
    :members:
    :undoc-members:
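As a rough illustration of the BlockMask utilities documented here, the sketch below builds a block mask from a 4-argument mask function. The `(fn, B, H, M, N, device)` argument order and the CUDA default are assumptions based on the helper signatures visible later in this diff, not guarantees about the public API:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    # A mask function takes (batch, head, q_idx, kv_idx) and returns a bool
    return q_idx >= kv_idx

# Argument order assumed from _create_block_mask_inner further down in this diff
block_mask = create_block_mask(causal, 1, 1, 2048, 2048, "cuda")
print(block_mask)  # a BlockMask describing which KV blocks each Q-block row needs
```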

View File

@@ -20,9 +20,11 @@ Submodules
.. autosummary::
    :nosignatures:

    flex_attention
    bias

.. toctree::
    :hidden:

    nn.attention.flex_attention
    nn.attention.bias

View File

@@ -16,7 +16,7 @@ from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch._inductor import metrics
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention._flex_attention import (
from torch.nn.attention.flex_attention import (
_causal,
_compose,
_create_empty_block_mask,

View File

@@ -13,7 +13,7 @@ import torch
from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
from torch._inductor.test_case import TestCase as InductorTestCase
from torch._inductor.utils import run_and_get_code
from torch.nn.attention._flex_attention import (
from torch.nn.attention.flex_attention import (
_causal,
_compose,
_create_empty_block_mask,

View File

@@ -136,7 +136,7 @@ def _math_attention_inner(
n = torch.arange(0, scores.size(3), device=scores.device)
captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
from torch.nn.attention._flex_attention import _vmap_for_bhqkv
from torch.nn.attention.flex_attention import _vmap_for_bhqkv
# first input is score
score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim)
@@ -690,7 +690,7 @@ def sdpa_dense_backward(
# Gradient of the inline score_mod function, with respect to the scores
captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers)
from torch.nn.attention._flex_attention import _vmap_for_bhqkv
from torch.nn.attention.flex_attention import _vmap_for_bhqkv
# inputs are [score, b, h, q_idx, kv_idx, gradOut, ...]
# score and gradOut are "fully" batched

View File

@@ -42,13 +42,25 @@ _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor]
_mask_fn_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
class ModificationType(Enum):
class _ModificationType(Enum):
"""Enum for the type of modification function.
- SCORE_MOD: score_mod function which accepts a score as the first argument
- MASK_FN: mask function which does not accept a score and is only used for generating
block mask
"""
SCORE_MOD = 1
MASK_FN = 2
@torch._dynamo.assume_constant_result
def get_mod_type(fn) -> ModificationType:
def _get_mod_type(fn: Callable) -> _ModificationType:
"""Get the type of modification function.
This function inspects the number of positional arguments of the function to determine
the type of modification function. If the function has 5 positional arguments, it is
considered as a score_mod function. If the function has 4 positional arguments, it is
considered as a mask function.
"""
num_positional_args = sum(
1
for param in inspect.signature(fn).parameters.values()
@@ -56,9 +68,9 @@ def get_mod_type(fn) -> ModificationType:
)
assert num_positional_args == 5 or num_positional_args == 4
if num_positional_args == 5:
return ModificationType.SCORE_MOD
return _ModificationType.SCORE_MOD
elif num_positional_args == 4:
return ModificationType.MASK_FN
return _ModificationType.MASK_FN
else:
raise AssertionError
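The arity rule that the renamed `_get_mod_type` encodes can be shown standalone. The "parameters without a default" check is an assumption filled in from the condition that is cut off by the hunk boundary above:

```python
import inspect

def a_score_mod(score, b, h, q_idx, kv_idx):  # 5 positional parameters
    return score

def a_mask_fn(b, h, q_idx, kv_idx):           # 4 positional parameters
    return q_idx >= kv_idx

def positional_arity(fn):
    # Mirrors the counting logic above: parameters with no default value
    return sum(
        1
        for p in inspect.signature(fn).parameters.values()
        if p.default is inspect.Parameter.empty
    )

assert positional_arity(a_score_mod) == 5  # -> _ModificationType.SCORE_MOD
assert positional_arity(a_mask_fn) == 4    # -> _ModificationType.MASK_FN
```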
@@ -114,51 +126,59 @@ class BlockMask:
BlockMask is our format for representing a block-sparse attention mask.
It is somewhat of a cross in-between BCSR and a non-sparse format.
## Basics
Basics
------
A block-sparse mask means that instead of representing the sparsity of
individual elements in the mask, we only consider a block sparse if an
entire KV_BLOCK_SIZE x Q_BLOCK_SIZE is sparse. This aligns well with
hardware, which generally expects to perform contiguous loads and
computation.
individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is
considered sparse only if every element within that block is sparse.
This aligns well with hardware, which generally expects to perform
contiguous loads and computation.
This format is primarily optimized for 1. simplicity, and 2. kernel
efficiency. Notably, it is *not* optimized for size, as we believe the mask
is sufficiently small that its size is not a concern.
The essentials of our format are:
num_blocks_in_row: Tensor[ROWS] # Describes the number of blocks present in
each row.
col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL] # col_indices[i] is the
position of the blocks in index i. The values of this row after
col_indices[i][num_blocks_in_row[i]] are undefined.
For example, to reconstruct the original tensor from this format.
```
dense_mask = torch.zeros(ROWS, COLS)
for row in range(ROWS):
    for block_idx in range(num_blocks_in_row[row]):
        dense_mask[row, col_indices[row, block_idx]] = 1
```
- num_blocks_in_row: Tensor[ROWS]
Describes the number of blocks present in each row.
- col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]
`col_indices[i]` is the sequence of block positions for row i. The values of
this row after `col_indices[i][num_blocks_in_row[i]]` are undefined.
For example, to reconstruct the original tensor from this format:
.. code-block:: python

    dense_mask = torch.zeros(ROWS, COLS)
    for row in range(ROWS):
        for block_idx in range(num_blocks_in_row[row]):
            dense_mask[row, col_indices[row, block_idx]] = 1
Notably, this format makes it easier to implement a reduction along the
*rows* of the mask.
## Details
The basics of our format require only kv_num_blocks and kv_indices. But, we have up to 8 tensors on this object. This represents 4 pairs:
Details
-------
The basics of our format require only kv_num_blocks and kv_indices. But, we
have up to 8 tensors on this object. This represents 4 pairs:
(kv_num_blocks, kv_indices): This is used for the forwards pass of
attention, as we reduce along the KV dimension.
(q_num_blocks, q_indices): This is required for the backwards pass, as
computing dKV requires iterating along the mask along the Q dimension.
[OPTIONAL](full_kv_num_blocks, full_kv_indices): This is optional, and is
purely an optimization. As it turns out, applying masking to every block is
quite expensive! If we specifically know which blocks are "full" and don't
require masking at all, then we can skip applying mask_mod to these blocks.
This requires the user to split out a separate mask_mod from the score_mod.
For causal masks, this is about a 15% speedup.
[OPTIONAL](full_q_num_blocks, full_q_indices): Same as above, but for the
backwards.
1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as
we reduce along the KV dimension.
2. (q_num_blocks, q_indices): Required for the backwards pass, as computing
dKV requires iterating along the mask along the Q dimension.
3. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and
purely an optimization. As it turns out, applying masking to every block
is quite expensive! If we specifically know which blocks are "full" and
don't require masking at all, then we can skip applying mask_mod to these
blocks. This requires the user to split out a separate mask_mod from the
score_mod. For causal masks, this is about a 15% speedup.
4. [OPTIONAL] (full_q_num_blocks, full_q_indices): Same as above, but for
the backwards pass.
"""
kv_num_blocks: Tensor
kv_indices: Tensor
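To make the format concrete, here is a small standalone sketch (not part of this commit) that derives a `(kv_num_blocks, kv_indices)` pair from a dense block-level mask, the inverse of the reconstruction loop shown in the docstring above:

```python
import torch

# Block-level visibility: entry [q_block, kv_block] is True when that
# KV block must be loaded for that row of Q blocks (a causal pattern here).
dense_blocks = torch.tensor([
    [True,  False, False, False],
    [True,  True,  False, False],
    [True,  True,  True,  False],
    [True,  True,  True,  True ],
])

kv_num_blocks = dense_blocks.sum(dim=-1)  # blocks needed per row
# A stable descending sort pushes the needed block columns to the front;
# entries past kv_num_blocks[i] carry no meaning, as the format allows.
kv_indices = torch.argsort(dense_blocks.int(), dim=-1, descending=True, stable=True)

row = 2
print(kv_num_blocks[row])                          # tensor(3)
print(kv_indices[row, : int(kv_num_blocks[row])])  # tensor([0, 1, 2])
```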
@@ -184,7 +204,7 @@ class BlockMask:
full_q_indices: Optional[Tensor],
KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
mask_fn=None,
mask_fn: Optional[_mask_fn_signature] = None,
):
if kv_indices.dim() < 2:
raise RuntimeError("BlockMask must have at least 2 dimensions")
@@ -469,7 +489,7 @@ def create_mask(
r"""This function creates a mask tensor from a mod_fn function.
Args:
mod_fn (Callable): Function to modify attention scores.
mod_fn (Union[_score_mod_signature, _mask_fn_signature]): Function to modify attention scores.
B (int): Batch size.
H (int): Number of heads.
M (int): Sequence length of query.
@@ -491,16 +511,16 @@ def create_mask(
ctx = nullcontext()
else:
ctx = TransformGetItemToIndex() # type: ignore[assignment]
mod_type = get_mod_type(mod_fn)
mod_type = _get_mod_type(mod_fn)
with ctx:
if mod_type == ModificationType.SCORE_MOD:
if mod_type == _ModificationType.SCORE_MOD:
score_mod = mod_fn
score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,)) # first input is score
out = score_mod(torch.zeros(B, H, M, N, device=device), b, h, m, n)
mask = torch.where(torch.isneginf(out), False, True)
return mask
elif mod_type == ModificationType.MASK_FN:
elif mod_type == _ModificationType.MASK_FN:
mask_fn = mod_fn
mask_fn = _vmap_for_bhqkv(mask_fn, prefix=())
mask = mask_fn(b, h, m, n)
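In other words, for a score_mod the dense mask is inferred from where the modified scores are `-inf`, while a mask function's boolean output is used directly. A hypothetical example of the score_mod branch, assuming `create_mask` accepts `(mod_fn, B, H, M, N, device)` positionally (as the call in `_create_block_mask_inner` below suggests) and runs eagerly on CPU:

```python
import torch
from torch.nn.attention.flex_attention import create_mask

def causal_score_mod(score, b, h, q_idx, kv_idx):
    # Out-of-window positions get -inf, which create_mask turns into False
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

mask = create_mask(causal_score_mod, 1, 1, 16, 16, "cpu")
print(mask.shape, mask.dtype)  # expected: torch.Size([1, 1, 16, 16]) torch.bool
print(mask[0, 0, 0])           # only the first key position visible to query 0
```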
@@ -515,8 +535,8 @@ def _create_block_mask_inner(
mod_fn, B, H, M, N, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE, mod_type
):
mask_tensor = create_mask(mod_fn, B, H, M, N, device, _compile=True)
mod_type = get_mod_type(mod_fn)
if mod_type == ModificationType.MASK_FN:
mod_type = _get_mod_type(mod_fn)
if mod_type == _ModificationType.MASK_FN:
mask_fn = mod_fn
else:
mask_fn = None
@@ -558,7 +578,7 @@ def create_block_mask(
block_mask (tuple): A tuple of (kv_num_blocks, kv_indices, q_num_blocks, q_indices,
KV_BLOCK_SIZE, Q_BLOCK_SIZE) which represents the block mask.
"""
mod_type = get_mod_type(fn)
mod_type = _get_mod_type(fn)
inner_func = _create_block_mask_inner
# This is kind of a temporary hack to workaround some issues
if _compile:
@@ -618,14 +638,14 @@ def flex_attention(
score: Tensor,
batch: Tensor,
head: Tensor,
token_q: Tensor,
token_kv: Tensor
q_idx: Tensor,
kv_idx: Tensor
) -> Tensor:
Where:
- ``score``: A scalar tensor representing the attention score,
with the same data type and device as the query, key, and value tensors.
- ``b``, ``h``, ``q_idx``, ``kv_idx``: Scalar tensors indicating
- ``batch``, ``head``, ``q_idx``, ``kv_idx``: Scalar tensors indicating
the batch index, head index, query index, and key/value index, respectively.
These should have the ``torch.int`` data type and be located on the same device as the score tensor.
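For reference, a couple of score_mod callables matching the signature documented above (hypothetical illustrations, not taken from this commit; whether a CUDA device or torch.compile is required at this stage is not asserted here, so treat the call as illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def relative_position_bias(score, batch, head, q_idx, kv_idx):
    # All five inputs are scalar tensors, per the documented signature
    return score + (kv_idx - q_idx)

def causal_score_mod(score, batch, head, q_idx, kv_idx):
    # Masking expressed as a score_mod: disallowed positions score -inf
    return torch.where(q_idx >= kv_idx, score, -float("inf"))

q, k, v = (torch.randn(1, 2, 256, 64, device="cuda", dtype=torch.float16) for _ in range(3))
out = flex_attention(q, k, v, relative_position_bias)  # score_mod passed positionally
```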

View File

@@ -11,7 +11,7 @@ from torch.testing._internal.opinfo.core import (
)
from torch.testing._internal.common_dtype import all_types_and, custom_types
from torch.testing._internal.opinfo.core import DecorateInfo
from torch.nn.attention._flex_attention import flex_attention, _create_empty_block_mask
from torch.nn.attention.flex_attention import flex_attention, _create_empty_block_mask
def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
make_arg = functools.partial(