**Summary**
Today, the only way to have variable sequence length support in PyTorch attention is through nested tensors [here](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support). We also want to add an explicit lower-level API that provides variable sequence length support without padding/masking in SDPA.
This PR builds out `varlen_attn`, the public API that users can call for the forward method, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.
**Benchmarking**
To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.
Settings:
- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs
| | Variable Length API | SDPA |
|--------|--------------------|----------|
| Runtime | 0.21750560760498047 ms | 0.43171775817871094 ms |
| TFLOPs | 231.812 | 320.840 |
The sparsity is 0.453 which we can see matches the speedup we get from Varlen (approx 50%). TFLOPs remains around the same, with SDPA slightly larger due to potential higher overhead and total flops scaling with sequence length.
**Testing**
Run `python test/test_varlen_attention.py` for unit tests where we verify basic functionality and confirm numerical match between varlen outputs vs SDPA.
**Next steps**
Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.
(This stack builds on top of #162326)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
**Summary**
Today, the only way to have variable sequence length support in PyTorch attention is through nested tensors [here](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support). We also want to add an explicit lower-level API that provides variable sequence length support without padding/masking in SDPA.
This PR builds out `varlen_attn`, the public API that users can call for the forward method, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.
**Benchmarking**
To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.
Settings:
- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs
| | Variable Length API | SDPA |
|--------|--------------------|----------|
| Runtime | 0.21750560760498047 ms | 0.43171775817871094 ms |
| TFLOPs | 231.812 | 320.840 |
The sparsity is 0.453 which we can see matches the speedup we get from Varlen (approx 50%). TFLOPs remains around the same, with SDPA slightly larger due to potential higher overhead and total flops scaling with sequence length.
**Testing**
Run `python test/test_varlen_attention.py` for unit tests where we verify basic functionality and confirm numerical match between varlen outputs vs SDPA.
**Next steps**
Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.
(This stack builds on top of #162326)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg