Make FlexAttention API public (#130755)
# Summary

Makes the prototype API flex_attention public.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130755
Approved by: https://github.com/Chillee
Committed by PyTorch MergeBot
parent cbda8be537 · commit 2b43d339fe
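For illustration (not part of the commit): a minimal end-to-end sketch of the API this change makes public. Argument names follow the docstrings in the diff below; the exact public signature and defaults may differ, and the `B`/`H`/`M`/`N` keyword names for `create_block_mask` are an assumption borrowed from `create_mask`.

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# A score_mod takes the raw attention score plus four index tensors and
# returns a modified score; here, a simple relative-position bias.
def relative_bias(score, batch, head, q_idx, kv_idx):
    return score + (q_idx - kv_idx)

# A mask function takes only the four index tensors (no score argument).
def causal(batch, head, q_idx, kv_idx):
    return q_idx >= kv_idx

query = key = value = torch.randn(2, 4, 512, 64, device="cuda", dtype=torch.float16)
block_mask = create_block_mask(causal, B=2, H=4, M=512, N=512, device="cuda")
out = flex_attention(query, key, value, score_mod=relative_bias, block_mask=block_mask)
```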
docs/source/nn.attention.flex_attention.rst (new file, +23)
@@ -0,0 +1,23 @@
+.. role:: hidden
+    :class: hidden-section
+
+======================================
+torch.nn.attention.flex_attention
+======================================
+
+.. currentmodule:: torch.nn.attention.flex_attention
+.. py:module:: torch.nn.attention.flex_attention
+.. autofunction:: flex_attention
+
+BlockMask Utilities
+-------------------
+
+.. autofunction:: create_block_mask
+.. autofunction:: create_mask
+
+BlockMask
+---------
+
+.. autoclass:: BlockMask
+    :members:
+    :undoc-members:
@@ -20,9 +20,11 @@ Submodules
 .. autosummary::
     :nosignatures:
 
+    flex_attention
     bias
 
 .. toctree::
     :hidden:
 
+    nn.attention.flex_attention
     nn.attention.bias
@@ -16,7 +16,7 @@ from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
 from torch._inductor import metrics
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch._inductor.utils import run_and_get_code
-from torch.nn.attention._flex_attention import (
+from torch.nn.attention.flex_attention import (
     _causal,
     _compose,
     _create_empty_block_mask,
@@ -13,7 +13,7 @@ import torch
 from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop
 from torch._inductor.test_case import TestCase as InductorTestCase
 from torch._inductor.utils import run_and_get_code
-from torch.nn.attention._flex_attention import (
+from torch.nn.attention.flex_attention import (
     _causal,
     _compose,
     _create_empty_block_mask,
@@ -136,7 +136,7 @@ def _math_attention_inner(
     n = torch.arange(0, scores.size(3), device=scores.device)
 
     captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
-    from torch.nn.attention._flex_attention import _vmap_for_bhqkv
+    from torch.nn.attention.flex_attention import _vmap_for_bhqkv
 
     # first input is score
     score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim)
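The `_vmap_for_bhqkv` helper imported above lifts a pointwise `score_mod` over the batch/head/q/kv index tensors. A rough sketch of the idea, not the verbatim helper:

```python
import torch

def vmap_for_bhqkv_sketch(fn, prefix=(), suffix=()):
    # Nest one torch.vmap per index argument, innermost (kv_idx) first.
    # `prefix` gives in_dims for leading args such as the score ((0,) when
    # the score is fully batched, () for mask functions); `suffix` covers
    # captured buffers, typically mapped as None.
    dimensions = [
        (None, None, None, 0),  # map over kv_idx
        (None, None, 0, None),  # map over q_idx
        (None, 0, None, None),  # map over h
        (0, None, None, None),  # map over b
    ]
    for dims in dimensions:
        fn = torch.vmap(fn, in_dims=prefix + dims + suffix, out_dims=0)
    return fn
```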
@@ -690,7 +690,7 @@ def sdpa_dense_backward(
     # Gradient of the inline score_mod function, with respect to the scores
     captured_buffers_in_dim = (None,) * len(score_mod_other_buffers)
     out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers)
-    from torch.nn.attention._flex_attention import _vmap_for_bhqkv
+    from torch.nn.attention.flex_attention import _vmap_for_bhqkv
 
     # inputs are [score, b, h, q_idx, kv_idx, gradOut, ...]
     # score and gradOut are "fully" batched
@@ -42,13 +42,25 @@ _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor
 _mask_fn_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
 
 
-class ModificationType(Enum):
+class _ModificationType(Enum):
+    """Enum for the type of modification function.
+    - SCORE_MOD: score_mod function which accepts a score as the first argument
+    - MASK_FN: mask function which does not accept a score and is only used for generating
+    block mask
+    """
+
     SCORE_MOD = 1
     MASK_FN = 2
 
 
 @torch._dynamo.assume_constant_result
-def get_mod_type(fn) -> ModificationType:
+def _get_mod_type(fn: Callable) -> _ModificationType:
+    """Get the type of modification function.
+    This function inspects the number of positional arguments of the function to determine
+    the type of modification function. If the function has 5 positional arguments, it is
+    considered as a score_mod function. If the function has 4 positional arguments, it is
+    considered as a mask function.
+    """
     num_positional_args = sum(
         1
         for param in inspect.signature(fn).parameters.values()
@@ -56,9 +68,9 @@ def get_mod_type(fn) -> ModificationType:
     )
     assert num_positional_args == 5 or num_positional_args == 4
     if num_positional_args == 5:
-        return ModificationType.SCORE_MOD
+        return _ModificationType.SCORE_MOD
    elif num_positional_args == 4:
-        return ModificationType.MASK_FN
+        return _ModificationType.MASK_FN
    else:
        raise AssertionError
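The arity-based dispatch above in practice (hypothetical callables; `_get_mod_type` and `_ModificationType` are the private helpers renamed in this hunk):

```python
import torch
from torch.nn.attention.flex_attention import _get_mod_type, _ModificationType

def softcap_score_mod(score, b, h, q_idx, kv_idx):
    # 5 positional args -> treated as a score_mod
    return 30.0 * torch.tanh(score / 30.0)

def causal_mask(b, h, q_idx, kv_idx):
    # 4 positional args -> treated as a mask function
    return q_idx >= kv_idx

assert _get_mod_type(softcap_score_mod) == _ModificationType.SCORE_MOD
assert _get_mod_type(causal_mask) == _ModificationType.MASK_FN
```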
@@ -114,51 +126,59 @@ class BlockMask:
     BlockMask is our format for representing a block-sparse attention mask.
     It is somewhat of a cross in-between BCSR and a non-sparse format.
 
-    ## Basics
+    Basics
+    ------
     A block-sparse mask means that instead of representing the sparsity of
-    individual elements in the mask, we only consider a block sparse if an
-    entire KV_BLOCK_SIZE x Q_BLOCK_SIZE is sparse. This aligns well with
-    hardware, which generally expects to perform contiguous loads and
-    computation.
+    individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is
+    considered sparse only if every element within that block is sparse.
+    This aligns well with hardware, which generally expects to perform
+    contiguous loads and computation.
 
     This format is primarily optimized for 1. simplicity, and 2. kernel
     efficiency. Notably, it is *not* optimized for size, as we believe the mask
     is sufficiently small that its size is not a concern.
 
     The essentials of our format are:
-    num_blocks_in_row: Tensor[ROWS] # Describes the number of blocks present in
-    each row.
-    col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL] # col_indices[i] is the
-    position of the blocks in index i. The values of this row after
-    col_indices[i][num_blocks_in_row[i]] are undefined.
-
-    For example, to reconstruct the original tensor from this format.
-    ```
-    dense_mask = torch.zeros(ROWS, COLS)
-    for row in range(ROWS):
-        for block_idx in range(num_blocks_in_row[row]):
-            dense_mask[row, col_indices[row, block_idx]] = 1
-    ```
+    - num_blocks_in_row: Tensor[ROWS]
+      Describes the number of blocks present in each row.
+
+    - col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]
+      `col_indices[i]` is the sequence of block positions for row i. The values of
+      this row after `col_indices[i][num_blocks_in_row[i]]` are undefined.
+
+    For example, to reconstruct the original tensor from this format:
+
+    .. code-block:: python
+
+        dense_mask = torch.zeros(ROWS, COLS)
+        for row in range(ROWS):
+            for block_idx in range(num_blocks_in_row[row]):
+                dense_mask[row, col_indices[row, block_idx]] = 1
+
     Notably, this format makes it easier to implement a reduction along the
     *rows* of the mask.
 
-    ## Details
-    The basics of our format require only kv_num_blocks and kv_indices. But, we have up to 8 tensors on this object. This represents 4 pairs:
+    Details
+    -------
+    The basics of our format require only kv_num_blocks and kv_indices. But, we
+    have up to 8 tensors on this object. This represents 4 pairs:
 
-    (kv_num_blocks, kv_indices): This is used for the forwards pass of
-    attention, as we reduce along the KV dimension.
-    (q_num_blocks, q_indices): This is required for the backwards pass, as
-    computing dKV requires iterating along the mask along the Q dimension.
-    [OPTIONAL](full_kv_num_blocks, full_kv_indices): This is optional, and is
-    purely an optimization. As it turns out, applying masking to every block is
-    quite expensive! If we specifically know which blocks are "full" and don't
-    require masking at all, then we can skip applying mask_mod to these blocks.
-    This requires the user to split out a separate mask_mod from the score_mod.
-    For causal masks, this is about a 15% speedup.
-    [OPTIONAL](full_q_num_blocks, full_q_indices): Same as above, but for the
-    backwards.
+    1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as
+       we reduce along the KV dimension.
+
+    2. (q_num_blocks, q_indices): Required for the backwards pass, as computing
+       dKV requires iterating along the mask along the Q dimension.
+
+    3. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and
+       purely an optimization. As it turns out, applying masking to every block
+       is quite expensive! If we specifically know which blocks are "full" and
+       don't require masking at all, then we can skip applying mask_mod to these
+       blocks. This requires the user to split out a separate mask_mod from the
+       score_mod. For causal masks, this is about a 15% speedup.
+
+    4. [OPTIONAL] (full_q_num_blocks, full_q_indices): Same as above, but for
+       the backwards pass.
     """
     kv_num_blocks: Tensor
     kv_indices: Tensor
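Point 3 of the docstring above is the main usability consequence: to get the "full block" fast path, masking logic must live in a standalone mask function instead of being fused into the score_mod. A hedged sketch of the two styles:

```python
import torch

# Fused: masking via -inf inside the score_mod. Correct, but the modification
# must run on every block, so no block can be skipped as "full".
def causal_fused(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# Split: a standalone mask function lets block-mask construction mark fully
# unmasked blocks as "full" and skip masking there (~15% for causal, per above).
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx
```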
@@ -184,7 +204,7 @@ class BlockMask:
         full_q_indices: Optional[Tensor],
         KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
         Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE,
-        mask_fn=None,
+        mask_fn: Optional[_mask_fn_signature] = None,
     ):
         if kv_indices.dim() < 2:
             raise RuntimeError("BlockMask must have at least 2 dimensions")
@@ -469,7 +489,7 @@ def create_mask(
     r"""This function creates a mask tensor from a mod_fn function.
 
     Args:
-        mod_fn (Callable): Function to modify attention scores.
+        mod_fn (Union[_score_mod_signature, _mask_fn_signature]): Function to modify attention scores.
         B (int): Batch size.
         H (int): Number of heads.
         M (int): Sequence length of query.
@@ -491,16 +511,16 @@ def create_mask(
         ctx = nullcontext()
     else:
         ctx = TransformGetItemToIndex()  # type: ignore[assignment]
-    mod_type = get_mod_type(mod_fn)
+    mod_type = _get_mod_type(mod_fn)
 
     with ctx:
-        if mod_type == ModificationType.SCORE_MOD:
+        if mod_type == _ModificationType.SCORE_MOD:
             score_mod = mod_fn
             score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,))  # first input is score
             out = score_mod(torch.zeros(B, H, M, N, device=device), b, h, m, n)
             mask = torch.where(torch.isneginf(out), False, True)
             return mask
-        elif mod_type == ModificationType.MASK_FN:
+        elif mod_type == _ModificationType.MASK_FN:
             mask_fn = mod_fn
             mask_fn = _vmap_for_bhqkv(mask_fn, prefix=())
             mask = mask_fn(b, h, m, n)
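The branch above is why both function types can drive mask creation: a score_mod is evaluated on zero scores and `-inf` outputs are treated as masked, while a mask function returns booleans directly. A sketch of the equivalence (keyword names from the docstring above; the CPU device choice is an assumption):

```python
import torch
from torch.nn.attention.flex_attention import create_mask

WINDOW = 16

def windowed_score_mod(score, b, h, q_idx, kv_idx):
    inside = (q_idx - kv_idx).abs() <= WINDOW
    return torch.where(inside, score, float("-inf"))

def windowed_mask_fn(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx).abs() <= WINDOW

# Both paths should yield the same boolean [B, H, M, N] mask.
m1 = create_mask(windowed_score_mod, B=1, H=1, M=128, N=128, device="cpu")
m2 = create_mask(windowed_mask_fn, B=1, H=1, M=128, N=128, device="cpu")
assert torch.equal(m1, m2)
```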
@@ -515,8 +535,8 @@ def _create_block_mask_inner(
     mod_fn, B, H, M, N, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE, mod_type
 ):
     mask_tensor = create_mask(mod_fn, B, H, M, N, device, _compile=True)
-    mod_type = get_mod_type(mod_fn)
-    if mod_type == ModificationType.MASK_FN:
+    mod_type = _get_mod_type(mod_fn)
+    if mod_type == _ModificationType.MASK_FN:
         mask_fn = mod_fn
     else:
         mask_fn = None
@@ -558,7 +578,7 @@ def create_block_mask(
         block_mask (tuple): A tuple of (kv_num_blocks, kv_indices, q_num_blocks, q_indices,
         KV_BLOCK_SIZE, Q_BLOCK_SIZE) which represents the block mask.
     """
-    mod_type = get_mod_type(fn)
+    mod_type = _get_mod_type(fn)
     inner_func = _create_block_mask_inner
     # This is kind of a temporary hack to workaround some issues
     if _compile:
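For orientation, a sketch of calling `create_block_mask` and inspecting the BlockMask tensors described earlier (at this commit the function returns a BlockMask object; the keyword names mirror `create_mask` and are an assumption):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

bm = create_block_mask(causal, B=1, H=1, M=2048, N=2048, device="cuda")
# kv_num_blocks[b, h, q_block] counts the partially-masked KV blocks each
# query-block row visits; kv_indices lists their positions. The full_* pair,
# if present, tracks blocks that need no masking at all.
print(bm.kv_num_blocks.shape, bm.kv_indices.shape)
```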
@@ -618,14 +638,14 @@ def flex_attention(
             score: Tensor,
             batch: Tensor,
             head: Tensor,
-            token_q: Tensor,
-            token_kv: Tensor
+            q_idx: Tensor,
+            kv_idx: Tensor
         ) -> Tensor:
 
     Where:
         - ``score``: A scalar tensor representing the attention score,
           with the same data type and device as the query, key, and value tensors.
-        - ``b``, ``h``, ``q_idx``, ``kv_idx``: Scalar tensors indicating
+        - ``batch``, ``head``, ``q_idx``, ``kv_idx``: Scalar tensors indicating
           the batch index, head index, query index, and key/value index, respectively.
           These should have the ``torch.int`` data type and be located on the same device as the score tensor.
 
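As a concrete instance of the renamed signature (a hedged sketch in the style of ALiBi, not taken from the diff):

```python
import torch
from torch import Tensor

def alibi_score_mod(
    score: Tensor,
    batch: Tensor,
    head: Tensor,
    q_idx: Tensor,
    kv_idx: Tensor,
) -> Tensor:
    # Per-head linear bias on query/key distance; all five inputs are scalar
    # tensors, matching the signature documented above.
    slope = torch.exp2(-(head.to(torch.float32) + 1.0))
    return score + slope * (kv_idx - q_idx)
```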
@@ -11,7 +11,7 @@ from torch.testing._internal.opinfo.core import (
 )
 from torch.testing._internal.common_dtype import all_types_and, custom_types
 from torch.testing._internal.opinfo.core import DecorateInfo
-from torch.nn.attention._flex_attention import flex_attention, _create_empty_block_mask
+from torch.nn.attention.flex_attention import flex_attention, _create_empty_block_mask
 
 def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs):
     make_arg = functools.partial(