**Summary**

Today, the only way to have variable sequence length support in PyTorch attention is through nested tensors [here](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support). We also want to add an explicit lower-level API that provides variable sequence length support without padding/masking in SDPA.

This PR builds out `varlen_attn`, the public API that users can call for the forward method, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, we set sequences to be random multiples of 64 up to `max_seq_len`
- 100 runs

|        | Variable Length API | SDPA     |
|--------|--------------------|----------|
| Runtime | 0.21750560760498047 ms       | 0.43171775817871094 ms  |
| TFLOPs | 231.812         | 320.840  |

The sparsity is 0.453 which we can see matches the speedup we get from Varlen (approx 50%). TFLOPs remains around the same, with SDPA slightly larger due to potential higher overhead and total flops scaling with sequence length.

**Testing**

Run `python test/test_varlen_attention.py` for unit tests where we verify basic functionality and confirm numerical match between varlen outputs vs SDPA.

**Next steps**

Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.

(This stack builds on top of #162326)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
This commit is contained in:
Angel Li
2025-10-15 08:10:58 -07:00
committed by PyTorch MergeBot
parent 2b71b62045
commit 78f5a1ec60
5 changed files with 421 additions and 1 deletions

View File

@ -23,6 +23,7 @@ Submodules
flex_attention
bias
experimental
varlen
.. toctree::
:hidden:
@ -30,3 +31,4 @@ Submodules
nn.attention.flex_attention
nn.attention.bias
nn.attention.experimental
nn.attention.varlen

View File

@ -0,0 +1,17 @@
```{eval-rst}
.. role:: hidden
:class: hidden-section
```
# torch.nn.attention.varlen
```{eval-rst}
.. automodule:: torch.nn.attention.varlen
.. currentmodule:: torch.nn.attention.varlen
```
```{eval-rst}
.. autofunction:: varlen_attn
```
```{eval-rst}
.. autoclass:: AuxRequest
```

View File

@ -0,0 +1,195 @@
# Owner(s): ["module: sdpa"]
import unittest
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import varlen_attn
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import parametrize, run_tests
VarlenShape = namedtuple(
"VarlenShape", ["batch_size", "max_seq_len", "embed_dim", "num_heads"]
)
default_tolerances = {
torch.float16: {"atol": 1e-1, "rtol": 1e-1},
torch.bfloat16: {"atol": 9e-2, "rtol": 5e-2},
torch.float32: {"atol": 1e-5, "rtol": 1.3e-6},
}
class AttentionBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.qkv_proj = nn.Linear(
embed_dim, 3 * embed_dim, bias=False, device=device, dtype=dtype
)
self.out_proj = nn.Linear(
embed_dim, embed_dim, bias=False, device=device, dtype=dtype
)
def forward_varlen(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
qkv = self.qkv_proj(x_packed)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)
attn_out = varlen_attn(
q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
)
attn_out = attn_out.view(-1, self.embed_dim)
return self.out_proj(attn_out)
def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
batch_size, seq_len, _ = x_padded.shape
qkv = self.qkv_proj(x_padded)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
attn_out = (
attn_out.transpose(1, 2)
.contiguous()
.view(batch_size, seq_len, self.embed_dim)
)
return self.out_proj(attn_out)
def create_variable_length_batch(
shape: VarlenShape, device: torch.device, dtype: torch.dtype
):
seq_lengths = []
for _ in range(shape.batch_size):
length = torch.randint(1, shape.max_seq_len // 64 + 1, (1,)).item() * 64
seq_lengths.append(min(length, shape.max_seq_len))
seq_lengths = torch.tensor(seq_lengths, device=device)
total_tokens = seq_lengths.sum().item()
x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)
cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)
max_len = seq_lengths.max().item()
x_padded = torch.zeros(
shape.batch_size, max_len, shape.embed_dim, device=device, dtype=dtype
)
start_idx = 0
for i, seq_len in enumerate(seq_lengths):
end_idx = start_idx + seq_len
x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
start_idx = end_idx
return {
"seq_lengths": seq_lengths,
"cu_seq": cu_seq,
"x_packed": x_packed,
"x_padded": x_padded,
"max_len": max_len,
"total_tokens": total_tokens,
}
class TestVarlenAttention(NNTestCase):
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_basic_functionality(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
output = attention_block.forward_varlen(
x_packed, cu_seq, shape.max_seq_len, is_causal=False
)
self.assertEqual(output.shape, (total_tokens, shape.embed_dim))
self.assertEqual(output.device, torch.device(device))
self.assertEqual(output.dtype, dtype)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
@parametrize("is_causal", [False, True])
def test_varlen_vs_sdpa(self, device, dtype, is_causal):
torch.manual_seed(42)
shape = VarlenShape(
batch_size=8, max_seq_len=2048, embed_dim=1024, num_heads=16
)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
variable_length_batch_data = create_variable_length_batch(shape, device, dtype)
varlen_output = attention_block.forward_varlen(
variable_length_batch_data["x_packed"],
variable_length_batch_data["cu_seq"],
variable_length_batch_data["max_len"],
is_causal=is_causal,
)
sdpa_output = attention_block.forward_sdpa(
variable_length_batch_data["x_padded"], is_causal=is_causal
)
tolerances = default_tolerances[dtype]
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
varlen_seq = varlen_output[start_idx:end_idx]
sdpa_seq = sdpa_output[i, :seq_len]
torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
start_idx = end_idx
device_types = ("cuda",)
instantiate_device_type_tests(TestVarlenAttention, globals(), only_for=device_types)
if __name__ == "__main__":
run_tests()

View File

@ -14,8 +14,15 @@ from torch.backends.cuda import (
SDPAParams,
)
from .varlen import varlen_attn
__all__: list[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
__all__: list[str] = [
"SDPBackend",
"sdpa_kernel",
"WARN_FOR_UNFUSED_KERNELS",
"varlen_attn",
]
# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses

View File

@ -0,0 +1,199 @@
"""
Variable-length attention implementation using Flash Attention.
This module provides a high-level Python interface for variable-length attention
that calls into the optimized Flash Attention kernels.
"""
import logging
from functools import lru_cache
from typing import NamedTuple, Optional, Union
import torch
log = logging.getLogger(__name__)
__all__ = ["varlen_attn", "AuxRequest"]
@lru_cache(maxsize=8)
def _should_use_cudnn(device_index: int) -> bool:
"""Cache device capability check to avoid repeated CUDA calls."""
return False
class AuxRequest(NamedTuple):
"""
Request which auxiliary outputs to compute from varlen_attn.
Each field is a boolean indicating whether that auxiliary output should be computed.
"""
lse: bool = False
# import failures when I try to register as custom op
# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
def _varlen_attn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Private custom op for variable-length attention.
This is the internal implementation. Users should use the public varlen_attn function instead.
"""
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
if use_cudnn:
log.info("Using cuDNN backend for varlen_attn")
result = torch.ops.aten._cudnn_attention_forward(
query,
key,
value,
None, # attn_bias
cu_seq_q,
cu_seq_k,
max_q,
max_k,
True, # compute_log_sumexp
0.0, # dropout_p hardcoded to 0.0
is_causal,
False, # return_debug_mask
)
# cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
output, softmax_lse = result[0], result[1]
else:
log.info("Using Flash Attention backend for varlen_attn")
output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
query,
key,
value,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0, # dropout_p hardcoded to 0.0
is_causal,
return_debug_mask=False,
)
return output, softmax_lse
# @_varlen_attn.register_fake
def _varlen_attn_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
Based on the 3D varlen path from meta__flash_attention_forward:
- query shape: (total, num_heads, head_dim)
- logsumexp shape: (num_heads, total_q)
"""
# Output has same shape as query
output = torch.empty_like(query)
# For varlen path: logsumexp shape is (num_heads, total_q)
total_q = query.size(0)
num_heads = query.size(1)
logsumexp = torch.empty(
(num_heads, total_q), dtype=torch.float, device=query.device
)
return output, logsumexp
def varlen_attn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool = False,
return_aux: Optional[AuxRequest] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
Compute variable-length attention using Flash Attention.
This function is similar to scaled_dot_product_attention but optimized for
variable-length sequences using cumulative sequence position tensors.
Args:
- query (Tensor): Query tensor; shape :math:`(T_q, H, D)`
- key (Tensor): Key tensor; shape :math:`(T_k, H, D)`
- value (Tensor): Value tensor; shape :math:`(T_k, H, D)`
- cu_seq_q (Tensor): Cumulative sequence positions for queries; shape :math:`(N+1,)`
- cu_seq_k (Tensor): Cumulative sequence positions for keys/values; shape :math:`(N+1,)`
- max_q (int): Maximum query sequence length in the batch.
- max_k (int): Maximum key/value sequence length in the batch.
- is_causal (bool, optional): If set to True, applies causal masking (default: False).
- return_aux (Optional[AuxRequest]): If not None and ``return_aux.lse`` is True, also returns the logsumexp tensor.
Shape legend:
- :math:`N`: Batch size
- :math:`T_q`: Total number of query tokens in the batch (sum of all query sequence lengths)
- :math:`T_k`: Total number of key/value tokens in the batch (sum of all key/value sequence lengths)
- :math:`H`: Number of attention heads
- :math:`D`: Head dimension
Returns:
- Tensor: Output tensor from attention computation
- If ``return_aux`` is not None and ``return_aux.lse`` is True, returns a tuple of Tensors:
(output, lse), where lse is the logsumexp
Example::
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16
>>> head_dim = embed_dim // num_heads
>>> seq_lengths = []
>>> for _ in range(batch_size):
... length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
... seq_lengths.append(min(length, max_seq_len))
>>> seq_lengths = torch.tensor(seq_lengths, device="cuda")
>>> total_tokens = seq_lengths.sum().item()
>>>
>>> # Create packed query, key, value tensors
>>> query = torch.randn(
... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>> key = torch.randn(
... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>> value = torch.randn(
... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>>
>>> # Build cumulative sequence tensor
>>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
>>> cu_seq[1:] = seq_lengths.cumsum(0)
>>> max_len = seq_lengths.max().item()
>>>
>>> # Call varlen_attn
>>> output = varlen_attn(
... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
... )
"""
out, lse = _varlen_attn(
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
)
if return_aux is not None and return_aux.lse:
return out, lse
return out