mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Make FlexAttention API public (#130755)
# Summary
Makes the prototype API flex_attention public.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130755
Approved by: https://github.com/Chillee
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							cbda8be537
						
					
				
				
					commit
					2b43d339fe
				
			
							
								
								
									
										23
									
								
								docs/source/nn.attention.flex_attention.rst
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								docs/source/nn.attention.flex_attention.rst
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,23 @@ | ||||
| .. role:: hidden | ||||
|     :class: hidden-section | ||||
|  | ||||
| ====================================== | ||||
| torch.nn.attention.flex_attention | ||||
| ====================================== | ||||
|  | ||||
| .. currentmodule:: torch.nn.attention.flex_attention | ||||
| .. py:module:: torch.nn.attention.flex_attention | ||||
| .. autofunction:: flex_attention | ||||
|  | ||||
| BlockMask Utilities | ||||
| ------------------- | ||||
|  | ||||
| .. autofunction:: create_block_mask | ||||
| .. autofunction:: create_mask | ||||
|  | ||||
| BlockMask | ||||
| --------- | ||||
|  | ||||
| .. autoclass:: BlockMask | ||||
|     :members: | ||||
|     :undoc-members: | ||||
| @ -20,9 +20,11 @@ Submodules | ||||
| .. autosummary:: | ||||
|     :nosignatures: | ||||
|  | ||||
|     flex_attention | ||||
|     bias | ||||
|  | ||||
| .. toctree:: | ||||
|     :hidden: | ||||
|  | ||||
|     nn.attention.flex_attention | ||||
|     nn.attention.bias | ||||
| @ -16,7 +16,7 @@ from torch._higher_order_ops.flex_attention import flex_attention as flex_attent | ||||
| from torch._inductor import metrics | ||||
| from torch._inductor.test_case import TestCase as InductorTestCase | ||||
| from torch._inductor.utils import run_and_get_code | ||||
| from torch.nn.attention._flex_attention import ( | ||||
| from torch.nn.attention.flex_attention import ( | ||||
|     _causal, | ||||
|     _compose, | ||||
|     _create_empty_block_mask, | ||||
|  | ||||
| @ -13,7 +13,7 @@ import torch | ||||
| from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop | ||||
| from torch._inductor.test_case import TestCase as InductorTestCase | ||||
| from torch._inductor.utils import run_and_get_code | ||||
| from torch.nn.attention._flex_attention import ( | ||||
| from torch.nn.attention.flex_attention import ( | ||||
|     _causal, | ||||
|     _compose, | ||||
|     _create_empty_block_mask, | ||||
|  | ||||
| @ -136,7 +136,7 @@ def _math_attention_inner( | ||||
|     n = torch.arange(0, scores.size(3), device=scores.device) | ||||
|  | ||||
|     captured_buffers_in_dim = (None,) * len(score_mod_other_buffers) | ||||
|     from torch.nn.attention._flex_attention import _vmap_for_bhqkv | ||||
|     from torch.nn.attention.flex_attention import _vmap_for_bhqkv | ||||
|  | ||||
|     # first input is score | ||||
|     score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,), suffix=captured_buffers_in_dim) | ||||
| @ -690,7 +690,7 @@ def sdpa_dense_backward( | ||||
|     # Gradient of the inline score_mod function, with respect to the scores | ||||
|     captured_buffers_in_dim = (None,) * len(score_mod_other_buffers) | ||||
|     out_dims = [0, None, None, None, None] + [None] * len(score_mod_other_buffers) | ||||
|     from torch.nn.attention._flex_attention import _vmap_for_bhqkv | ||||
|     from torch.nn.attention.flex_attention import _vmap_for_bhqkv | ||||
|  | ||||
|     # inputs are [score, b, h, q_idx, kv_idx, gradOut, ...] | ||||
|     # score and gradOut are "fully" batched | ||||
|  | ||||
| @ -42,13 +42,25 @@ _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor | ||||
| _mask_fn_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor] | ||||
| 
 | ||||
| 
 | ||||
| class ModificationType(Enum): | ||||
| class _ModificationType(Enum): | ||||
|     """Enum for the type of modification function. | ||||
|     - SCORE_MOD: score_mod function which accepts a score as the first argument | ||||
|     - MASK_FN: mask function which does not accept a score and is only used for generating | ||||
|     block mask | ||||
|     """ | ||||
| 
 | ||||
|     SCORE_MOD = 1 | ||||
|     MASK_FN = 2 | ||||
| 
 | ||||
| 
 | ||||
| @torch._dynamo.assume_constant_result | ||||
| def get_mod_type(fn) -> ModificationType: | ||||
| def _get_mod_type(fn: Callable) -> _ModificationType: | ||||
|     """Get the type of modification function. | ||||
|     This function inspects the number of positional arguments of the function to determine | ||||
|     the type of modification function. If the function has 5 positional arguments, it is | ||||
|     considered as a score_mod function. If the function has 4 positional arguments, it is | ||||
|     considered as a mask function. | ||||
|     """ | ||||
|     num_positional_args = sum( | ||||
|         1 | ||||
|         for param in inspect.signature(fn).parameters.values() | ||||
| @ -56,9 +68,9 @@ def get_mod_type(fn) -> ModificationType: | ||||
|     ) | ||||
|     assert num_positional_args == 5 or num_positional_args == 4 | ||||
|     if num_positional_args == 5: | ||||
|         return ModificationType.SCORE_MOD | ||||
|         return _ModificationType.SCORE_MOD | ||||
|     elif num_positional_args == 4: | ||||
|         return ModificationType.MASK_FN | ||||
|         return _ModificationType.MASK_FN | ||||
|     else: | ||||
|         raise AssertionError | ||||
| 
 | ||||
| @ -114,51 +126,59 @@ class BlockMask: | ||||
|     BlockMask is our format for representing a block-sparse attention mask. | ||||
|     It is somewhat of a cross in-between BCSR and a non-sparse format. | ||||
| 
 | ||||
|     ## Basics | ||||
|     Basics | ||||
|     ------ | ||||
|     A block-sparse mask means that instead of representing the sparsity of | ||||
|     individual elements in the mask, we only consider a block sparse if an | ||||
|     entire KV_BLOCK_SIZE x Q_BLOCK_SIZE is sparse. This aligns well with | ||||
|     hardware, which generally expects to perform contiguous loads and | ||||
|     computation. | ||||
|     individual elements in the mask, a KV_BLOCK_SIZE x Q_BLOCK_SIZE block is | ||||
|     considered sparse only if every element within that block is sparse. | ||||
|     This aligns well with hardware, which generally expects to perform | ||||
|     contiguous loads and computation. | ||||
| 
 | ||||
|     This format is primarily optimized for 1. simplicity, and 2. kernel | ||||
|     efficiency. Notably, it is *not* optimized for size, as we believe the mask | ||||
|     is sufficiently small that its size is not a concern. | ||||
| 
 | ||||
|     The essentials of our format are: | ||||
|     num_blocks_in_row: Tensor[ROWS]  # Describes the number of blocks present in | ||||
|     each row. | ||||
|     col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL]  # col_indices[i] is the | ||||
|     position of the blocks in index i. The values of this row after | ||||
|     col_indices[i][num_blocks_in_row[i]] are undefined. | ||||
| 
 | ||||
|     For example, to reconstruct the original tensor from this format. | ||||
|     ``` | ||||
|     dense_mask = torch.zeros(ROWS, COLS) | ||||
|     for row in range(ROWS): | ||||
|         for block_idx in range(num_blocks_in_row[row]): | ||||
|             dense_mask[row, col_indices[row, block_idx]] = 1 | ||||
|     ``` | ||||
|     - num_blocks_in_row: Tensor[ROWS] | ||||
|         Describes the number of blocks present in each row. | ||||
| 
 | ||||
|     - col_indices: Tensor[ROWS, MAX_BLOCKS_IN_COL] | ||||
|         `col_indices[i]` is the sequence of block positions for row i. The values of | ||||
|         this row after `col_indices[i][num_blocks_in_row[i]]` are undefined. | ||||
| 
 | ||||
|     For example, to reconstruct the original tensor from this format: | ||||
| 
 | ||||
|     .. code-block:: python | ||||
| 
 | ||||
|         dense_mask = torch.zeros(ROWS, COLS) | ||||
|         for row in range(ROWS): | ||||
|             for block_idx in range(num_blocks_in_row[row]): | ||||
|                 dense_mask[row, col_indices[row, block_idx]] = 1 | ||||
| 
 | ||||
|     Notably, this format makes it easier to implement a reduction along the | ||||
|     *rows* of the mask. | ||||
| 
 | ||||
|     ## Details | ||||
|     The basics of our format require only kv_num_blocks and kv_indices. But, we have up to 8 tensors on this object. This represents 4 pairs: | ||||
|     Details | ||||
|     ------- | ||||
|     The basics of our format require only kv_num_blocks and kv_indices. But, we | ||||
|     have up to 8 tensors on this object. This represents 4 pairs: | ||||
| 
 | ||||
|     (kv_num_blocks, kv_indices): This is used for the forwards pass of | ||||
|         attention, as we reduce along the KV dimension. | ||||
|     (q_num_blocks, q_indices): This is required for the backwards pass, as | ||||
|         computing dKV requires iterating along the mask along the Q dimension. | ||||
|     [OPTIONAL](full_kv_num_blocks, full_kv_indices): This is optional, and is | ||||
|         purely an optimization. As it turns out, applying masking to every block is | ||||
|         quite expensive! If we specifically know which blocks are "full" and don't | ||||
|         require masking at all, then we can skip applying mask_mod to these blocks. | ||||
|         This requires the user to split out a separate mask_mod from the score_mod. | ||||
|         For causal masks, this is about a 15% speedup. | ||||
|     [OPTIONAL](full_q_num_blocks, full_q_indices): Same as above, but for the | ||||
|     backwards. | ||||
|     1. (kv_num_blocks, kv_indices): Used for the forwards pass of attention, as | ||||
|     we reduce along the KV dimension. | ||||
| 
 | ||||
|     2. (q_num_blocks, q_indices): Required for the backwards pass, as computing | ||||
|     dKV requires iterating along the mask along the Q dimension. | ||||
| 
 | ||||
|     3. [OPTIONAL] (full_kv_num_blocks, full_kv_indices): This is optional and | ||||
|     purely an optimization. As it turns out, applying masking to every block | ||||
|     is quite expensive! If we specifically know which blocks are "full" and | ||||
|     don't require masking at all, then we can skip applying mask_mod to these | ||||
|     blocks. This requires the user to split out a separate mask_mod from the | ||||
|     score_mod. For causal masks, this is about a 15% speedup. | ||||
| 
 | ||||
|     4. [OPTIONAL] (full_q_num_blocks, full_q_indices): Same as above, but for | ||||
|     the backwards pass. | ||||
|     """ | ||||
|     kv_num_blocks: Tensor | ||||
|     kv_indices: Tensor | ||||
| @ -184,7 +204,7 @@ class BlockMask: | ||||
|         full_q_indices: Optional[Tensor], | ||||
|         KV_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, | ||||
|         Q_BLOCK_SIZE=_DEFAULT_SPARSE_BLOCK_SIZE, | ||||
|         mask_fn=None, | ||||
|         mask_fn: Optional[_mask_fn_signature] = None, | ||||
|     ): | ||||
|         if kv_indices.dim() < 2: | ||||
|             raise RuntimeError("BlockMask must have at least 2 dimensions") | ||||
| @ -469,7 +489,7 @@ def create_mask( | ||||
|     r"""This function creates a mask tensor from a mod_fn function. | ||||
| 
 | ||||
|     Args: | ||||
|         mod_fn (Callable): Function to modify attention scores. | ||||
|         mod_fn (Union[_score_mod_signature, _mask_fn_signature]): Function to modify attention scores. | ||||
|         B (int): Batch size. | ||||
|         H (int): Number of heads. | ||||
|         M (int): Sequence length of query. | ||||
| @ -491,16 +511,16 @@ def create_mask( | ||||
|         ctx = nullcontext() | ||||
|     else: | ||||
|         ctx = TransformGetItemToIndex()  # type: ignore[assignment] | ||||
|     mod_type = get_mod_type(mod_fn) | ||||
|     mod_type = _get_mod_type(mod_fn) | ||||
| 
 | ||||
|     with ctx: | ||||
|         if mod_type == ModificationType.SCORE_MOD: | ||||
|         if mod_type == _ModificationType.SCORE_MOD: | ||||
|             score_mod = mod_fn | ||||
|             score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,))  # first input is score | ||||
|             out = score_mod(torch.zeros(B, H, M, N, device=device), b, h, m, n) | ||||
|             mask = torch.where(torch.isneginf(out), False, True) | ||||
|             return mask | ||||
|         elif mod_type == ModificationType.MASK_FN: | ||||
|         elif mod_type == _ModificationType.MASK_FN: | ||||
|             mask_fn = mod_fn | ||||
|             mask_fn = _vmap_for_bhqkv(mask_fn, prefix=()) | ||||
|             mask = mask_fn(b, h, m, n) | ||||
| @ -515,8 +535,8 @@ def _create_block_mask_inner( | ||||
|     mod_fn, B, H, M, N, device, KV_BLOCK_SIZE, Q_BLOCK_SIZE, mod_type | ||||
| ): | ||||
|     mask_tensor = create_mask(mod_fn, B, H, M, N, device, _compile=True) | ||||
|     mod_type = get_mod_type(mod_fn) | ||||
|     if mod_type == ModificationType.MASK_FN: | ||||
|     mod_type = _get_mod_type(mod_fn) | ||||
|     if mod_type == _ModificationType.MASK_FN: | ||||
|         mask_fn = mod_fn | ||||
|     else: | ||||
|         mask_fn = None | ||||
| @ -558,7 +578,7 @@ def create_block_mask( | ||||
|         block_mask (tuple): A tuple of (kv_num_blocks, kv_indices, q_num_blocks, q_indices, | ||||
|                             KV_BLOCK_SIZE, Q_BLOCK_SIZE) which represents the block mask. | ||||
|     """ | ||||
|     mod_type = get_mod_type(fn) | ||||
|     mod_type = _get_mod_type(fn) | ||||
|     inner_func = _create_block_mask_inner | ||||
|     # This is kind of a temporary hack to workaround some issues | ||||
|     if _compile: | ||||
| @ -618,14 +638,14 @@ def flex_attention( | ||||
|             score: Tensor, | ||||
|             batch: Tensor, | ||||
|             head: Tensor, | ||||
|             token_q: Tensor, | ||||
|             token_kv: Tensor | ||||
|             q_idx: Tensor, | ||||
|             kv_idx: Tensor | ||||
|         ) -> Tensor: | ||||
| 
 | ||||
|     Where: | ||||
|         - ``score``: A scalar tensor representing the attention score, | ||||
|           with the same data type and device as the query, key, and value tensors. | ||||
|         - ``b``, ``h``, ``q_idx``, ``kv_idx``: Scalar tensors indicating | ||||
|         - ``batch``, ``head``, ``q_idx``, ``kv_idx``: Scalar tensors indicating | ||||
|           the batch index, head index, query index, and key/value index, respectively. | ||||
|           These should have the ``torch.int`` data type and be located on the same device as the score tensor. | ||||
| 
 | ||||
| @ -11,7 +11,7 @@ from torch.testing._internal.opinfo.core import ( | ||||
| ) | ||||
| from torch.testing._internal.common_dtype import all_types_and, custom_types | ||||
| from torch.testing._internal.opinfo.core import DecorateInfo | ||||
| from torch.nn.attention._flex_attention import flex_attention, _create_empty_block_mask | ||||
| from torch.nn.attention.flex_attention import flex_attention, _create_empty_block_mask | ||||
|  | ||||
| def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs): | ||||
|     make_arg = functools.partial( | ||||
|  | ||||
		Reference in New Issue
	
	Block a user