**Summary**

Today, the only way to get variable sequence length support in PyTorch attention is through nested tensors (see the [SDPA tutorial](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support)). We also want an explicit lower-level API that supports variable sequence lengths without the padding/masking that SDPA otherwise requires.

This PR builds out `varlen_attn`, the public API that users call for the forward pass, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.
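For context, here is a minimal sketch of the packed calling convention this kind of API follows (flash-attention style: all sequences concatenated along the token dimension and described by a cumulative-sequence-lengths tensor). The parameter names and ordering below (`cu_seq`, the `max_len` arguments, etc.) are illustrative, not the exact signature; see the `torch.nn.attention.varlen` docs added in this PR for the real one.

```python
import torch
from torch.nn.attention.varlen import varlen_attn

device, dtype = "cuda", torch.bfloat16
num_heads, head_dim = 16, 64
seq_lens = [128, 1024, 2048]  # true per-sequence lengths, no padding

# Packed layout: sequences concatenated -> (total_tokens, num_heads, head_dim)
total = sum(seq_lens)
q = torch.randn(total, num_heads, head_dim, device=device, dtype=dtype)
k, v = torch.randn_like(q), torch.randn_like(q)

# Cumulative sequence lengths mark each sequence's start offset: [0, 128, 1152, 3200]
cu_seq = torch.zeros(len(seq_lens) + 1, device=device, dtype=torch.int32)
cu_seq[1:] = torch.tensor(seq_lens, device=device).cumsum(0)

# Illustrative call: same offsets/max for Q and K since this is self-attention
out = varlen_attn(q, k, v, cu_seq, cu_seq, max(seq_lens), max(seq_lens), is_causal=False)
```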

**Benchmarking**

To benchmark, we compare runtime and achieved TFLOPs against the current SDPA approach, which pads every sequence to the maximum length (a sketch of the timing harness follows the settings list).

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs
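
A minimal sketch of a timing harness in this spirit, shown for the padded SDPA baseline (the actual benchmark script may differ; the varlen side follows the same pattern with packed inputs and `varlen_attn`):

```python
import torch
import torch.nn.functional as F

def bench_ms(fn, warmup=10, iters=100):
    """Average CUDA time per call in milliseconds, measured with events."""
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

B, S, H, D = 8, 2048, 16, 64  # embed_dim=1024 over 16 heads -> head_dim=64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

ms = bench_ms(lambda: F.scaled_dot_product_attention(q, k, v))
flops = 4 * B * H * S * S * D  # QK^T plus PV: two (S x S) matmuls over head_dim
print(f"{ms:.4f} ms/iter, {flops / (ms * 1e-3) / 1e12:.1f} TFLOPs")
```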

|         | Variable Length API | SDPA      |
|---------|---------------------|-----------|
| Runtime | 0.2175 ms           | 0.4317 ms |
| TFLOPs  | 231.812             | 320.840   |

The sparsity is 0.453 (the fraction of tokens in the padded layout that are real), which matches the observed speedup: 0.2175 ms / 0.4317 ms ≈ 0.50, i.e. varlen runs in roughly half the time. Achieved TFLOPs are in the same ballpark, with SDPA's figure somewhat higher, likely because its flop count is taken over the full padded sequence length (total flops scale with `max_seq_len` rather than the actual token count), plus potential measurement overhead.

**Testing**

Run `python test/test_varlen_attention.py` for unit tests that verify basic functionality and confirm that varlen outputs numerically match SDPA.
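
For illustration, a hedged sketch of such a parity check (shapes, names, and tolerances are illustrative, not copied from the test file): run each sequence through plain SDPA on its own and compare it against the matching slice of the packed varlen output.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention.varlen import varlen_attn

seq_lens = [64, 192, 448]
H, D = 16, 64
total = sum(seq_lens)
q = torch.randn(total, H, D, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

cu = torch.zeros(len(seq_lens) + 1, device="cuda", dtype=torch.int32)
cu[1:] = torch.tensor(seq_lens, device="cuda").cumsum(0)

out = varlen_attn(q, k, v, cu, cu, max(seq_lens), max(seq_lens))

for i in range(len(seq_lens)):
    s, e = int(cu[i]), int(cu[i + 1])
    # SDPA expects (batch, heads, seq, head_dim); run one unpadded sequence at a time
    ref = F.scaled_dot_product_attention(
        q[s:e].transpose(0, 1)[None],
        k[s:e].transpose(0, 1)[None],
        v[s:e].transpose(0, 1)[None],
    )[0].transpose(0, 1)
    torch.testing.assert_close(out[s:e], ref, atol=2e-2, rtol=2e-2)
```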

**Next steps**

Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.

(This stack builds on top of #162326)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
Author: Angel Li (2025-10-15 08:10:58 -07:00), committed by PyTorch MergeBot
Commit: 78f5a1ec60 (parent: 2b71b62045)
5 changed files, 421 additions, 1 deletion. Docs excerpts from the diff:

```diff
@@ -23,6 +23,7 @@ Submodules
     flex_attention
     bias
     experimental
+    varlen
 
 .. toctree::
     :hidden:
@@ -30,3 +31,4 @@ Submodules
     nn.attention.flex_attention
     nn.attention.bias
     nn.attention.experimental
+    nn.attention.varlen
```


````diff
@@ -0,0 +1,17 @@
+```{eval-rst}
+.. role:: hidden
+    :class: hidden-section
+```
+
+# torch.nn.attention.varlen
+
+```{eval-rst}
+.. automodule:: torch.nn.attention.varlen
+.. currentmodule:: torch.nn.attention.varlen
+```
+```{eval-rst}
+.. autofunction:: varlen_attn
+```
+```{eval-rst}
+.. autoclass:: AuxRequest
+```
````