This PR adds FlexAttention + NJT support. In particular:

* To handle raggedness, treats the packed sequence dim of input NJTs as a giant "stacked sequence". To ensure user `score_mod` / `mask_mod` functions can still be written in the original NJT sequence space, this PR automatically handles conversion of indices within the giant "stacked sequence" to sequence-relative indices.
* Provides `py_impls` for `NestedTensor` to the HOPs for flex attention forward / backward that simply wrap / unwrap NJTs appropriately
* Adds barebones `new_empty()` support to NJT since FlexAttention utilizes this repeatedly; right now, only `new_empty()` with a shape of `()` is supported
* Tests that FlexAttention with a causal mask matches causal SDPA
* Adds a new public API for FlexAttention usage:
    * `create_nested_block_mask(mask_mod, B, H, njt, BLOCK_SIZE, _compile)` - NJT analogue for `create_block_mask()` that utilizes the `njt`'s ragged structure to create an appropriately-sized block mask (e.g. `(1, 1, total_seqlen, total_seqlen)`). This function handles the index conversion from "stacked sequence" space to relative sequence space.
        * Minor note: as this is a public API, this function is purposefully named with "nested" instead of "njt" to keep the latter as an informal, mostly internal-only term.

Example usage:
```python
def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

query = ...  # NJT of shape (B, H, S*, D)
key = ...    # NJT of shape (B, H, S*, D)
value = ...  # NJT of shape (B, H, S*, D)

# create_nested_block_mask() automatically converts indices from
# "stacked sequence" space to relative sequence space
block_mask = create_nested_block_mask(causal_mask, 1, 1, query)
# block mask conceptual shape is (B, H, sum(S*), sum(S*))
output = flex_attention(query, key, value, block_mask=block_mask)

def causal_score_mod(score, b, h, q_idx, kv_idx):
    return torch.where(q_idx >= kv_idx, score, float("-inf"))

# flex_attention() automatically converts indices from "stacked sequence"
# space to relative sequence space for NJT inputs
output2 = flex_attention(query, key, value, score_mod=causal_score_mod)
```

TODO:

* ~~Determine the right level of abstraction for public API helpers + move them alongside other helpers~~ Verify this with others, though
* ~~Some cleanup~~
    * ~~`njt_score_mod_adapter`~~
    * ~~Q: should `create_njt_block_mask()` call `njt_mask_mod_adapter()` so we don't need two calls?~~
* Can we avoid materializing the `sum(s)`-length `seq_idx` used for conversion between stacked sequence and sequence-relative indices?
    * Not for now, although future work may deepen the integration between Flex + NJT (possibly requiring custom templates). We should try to cache this, though.
* ~~Demonstrate non-causal mask~~
* Support non-contiguous NJTs with holes (**booted to a future PR**)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136792
Approved by: https://github.com/drisspg
ghstack dependencies: #138841
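
To make the index conversion concrete, here is a minimal sketch of the adaptation, assuming sequences concatenated along the packed dim with `offsets` marking sequence boundaries. The names (`make_stacked_mask_mod`, `offsets`) are illustrative only, not the actual PyTorch internals:

```python
# Illustrative sketch (not the PyTorch implementation): adapting a mask_mod
# written in per-sequence coordinates to the "stacked sequence" space.
# Example: lengths [3, 5] -> offsets [0, 3, 8], stacked length sum(S*) = 8.
import torch

def make_stacked_mask_mod(mask_mod, offsets):
    # seq_idx[i] = which sequence stacked position i belongs to. This is the
    # sum(S*)-length tensor the TODO above hopes to avoid materializing / cache.
    lengths = offsets[1:] - offsets[:-1]
    seq_idx = torch.repeat_interleave(torch.arange(len(lengths)), lengths)

    def stacked_mask_mod(b, h, q_idx, kv_idx):
        # Mask out any (q, kv) pair that crosses a sequence boundary, then
        # evaluate the user's mask_mod on sequence-relative indices.
        same_seq = seq_idx[q_idx] == seq_idx[kv_idx]
        q_rel = q_idx - offsets[seq_idx[q_idx]]
        kv_rel = kv_idx - offsets[seq_idx[kv_idx]]
        return same_seq & mask_mod(b, h, q_rel, kv_rel)

    return stacked_mask_mod
```

With this shape of adapter, a user-supplied `causal_mask` keeps its per-sequence meaning while operating on the `(1, 1, sum(S*), sum(S*))` stacked attention problem.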
.. role:: hidden
    :class: hidden-section

======================================
torch.nn.attention.flex_attention
======================================

.. currentmodule:: torch.nn.attention.flex_attention
.. py:module:: torch.nn.attention.flex_attention

.. autofunction:: flex_attention

BlockMask Utilities
-------------------

.. autofunction:: create_block_mask
.. autofunction:: create_mask
.. autofunction:: create_nested_block_mask
.. autofunction:: and_masks
.. autofunction:: or_masks
.. autofunction:: noop_mask

BlockMask
---------

.. autoclass:: BlockMask
    :members:
    :undoc-members:
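
For orientation, a minimal composition sketch for the utilities documented above. The shapes, the window size, and the `causal` / `sliding_window` mask_mods are arbitrary assumptions; the calls mirror the public `torch.nn.attention.flex_attention` API:

```python
# Combine two mask_mods with and_masks(), build a BlockMask, and run
# flex_attention with it.
import torch
from torch.nn.attention.flex_attention import (
    and_masks,
    create_block_mask,
    flex_attention,
)

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= 256  # arbitrary window size for illustration

B, H, S, D = 2, 4, 1024, 64
query = torch.randn(B, H, S, D, device="cuda")
key = torch.randn(B, H, S, D, device="cuda")
value = torch.randn(B, H, S, D, device="cuda")

# B=None / H=None broadcast the mask across batch and heads.
mask_mod = and_masks(causal, sliding_window)
block_mask = create_block_mask(mask_mod, None, None, S, S, device="cuda")
output = flex_attention(query, key, value, block_mask=block_mask)
```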