Compare commits

...

4 Commits

Author SHA1 Message Date
1d9c98197e debugging cudnn numerics
ghstack-source-id: 606aeb4f2df49a79971047d83e6147cc958640af
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164950
2025-10-10 10:47:31 -07:00
3ad95533cb bwd pass
ghstack-source-id: 15eecab55e17a591636510e4ef2bce8392201839
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164504
2025-10-10 10:47:30 -07:00
a5026055bf register custom op
ghstack-source-id: 40f85bb85627ff56c33c754fc8ec7a27d5e84ee2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164503
2025-10-10 10:47:30 -07:00
56dbf23bb8 varlen api
ghstack-source-id: 19289b702e3aa23ee4e254fb30b452cc52b5a7a7
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
2025-10-10 10:47:29 -07:00
3 changed files with 613 additions and 1 deletion

View File

@@ -0,0 +1,343 @@
import unittest
from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention.varlen import varlen_attn
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import parametrize, run_tests
from torch.utils._python_dispatch import TorchDispatchMode
VarlenShape = namedtuple(
"VarlenShape", ["batch_size", "max_seq_len", "embed_dim", "num_heads"]
)
default_tolerances = {
torch.float16: {"atol": 1e-1, "rtol": 1e-1},
torch.bfloat16: {"atol": 9e-2, "rtol": 5e-2},
torch.float32: {"atol": 1e-5, "rtol": 1.3e-6},
}
class OpLoggingMode(TorchDispatchMode):
"""Logging mode that captures all dispatched operations"""
def __init__(self):
self.called_ops = []
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
op_name = str(func)
self.called_ops.append(op_name)
return func(*args, **(kwargs or {}))
class AttentionBlock(nn.Module):
def __init__(
self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
):
super().__init__()
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.qkv_proj = nn.Linear(
embed_dim, 3 * embed_dim, bias=False, device=device, dtype=dtype
)
self.out_proj = nn.Linear(
embed_dim, embed_dim, bias=False, device=device, dtype=dtype
)
# with torch.no_grad():
# self.qkv_proj.weight.zero_()
# for i in range(3):
# self.qkv_proj.weight[i*embed_dim:(i+1)*embed_dim, :] = torch.eye(
# embed_dim, device=device, dtype=dtype
# )
# self.out_proj.weight.zero_()
# self.out_proj.weight.copy_(
# torch.eye(embed_dim, device=device, dtype=dtype)
# )
def forward_varlen(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
qkv = self.qkv_proj(x_packed)
q, k, v = qkv.chunk(3, dim=-1)
q = q.view(-1, self.num_heads, self.head_dim)
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)
print(f"varlen q: {q}")
print(f"varlen k: {k}")
print(f"varlen v: {v}")
attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
attn_out = attn_out.view(-1, self.embed_dim)
return self.out_proj(attn_out)
def forward_sdpa(
self,
x_padded: torch.Tensor,
seq_lengths: torch.Tensor,
dtype: torch.dtype,
is_causal: bool = False,
):
batch_size, seq_len, _ = x_padded.shape
qkv = self.qkv_proj(x_padded)
q, k, v = qkv.chunk(3, dim=-1)
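# Key-padding mask: True marks valid (non-padded) positions per sequence, then
# broadcast from (batch, seq_len) to the (batch, num_heads, seq_q, seq_k) shape
# expected by scaled_dot_product_attention.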
mask = (
torch.arange(seq_len, device=x_padded.device)[None, :]
< seq_lengths[:, None]
)
attn_mask = mask[:, None, None, :].expand(
batch_size, self.num_heads, seq_len, seq_len
)
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
print(f"sdpa q: {q}")
print(f"sdpa k: {k}")
print(f"sdpa v: {v}")
attn_out = F.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=is_causal
)
attn_out = (
attn_out.transpose(1, 2)
.contiguous()
.view(batch_size, seq_len, self.embed_dim)
)
return self.out_proj(attn_out)
def create_variable_length_batch(
shape: VarlenShape, device: torch.device, dtype: torch.dtype
):
seq_lengths = []
for _ in range(shape.batch_size):
length = torch.randint(1, shape.max_seq_len // 64 + 1, (1,)).item() * 64
seq_lengths.append(min(length, shape.max_seq_len))
seq_lengths = torch.tensor(seq_lengths, device=device)
total_tokens = seq_lengths.sum().item()
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
)
cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)
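# cu_seq holds cumulative sequence boundaries, e.g. seq_lengths = [128, 64]
# gives cu_seq = [0, 128, 192]; sequence i occupies x_packed[cu_seq[i]:cu_seq[i+1]].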
max_len = seq_lengths.max().item()
x_padded = torch.zeros(
shape.batch_size, max_len, shape.embed_dim, device=device, dtype=dtype
)
start_idx = 0
for i, seq_len in enumerate(seq_lengths):
end_idx = start_idx + seq_len
x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
start_idx = end_idx
x_padded = x_padded.clone().detach().requires_grad_()
return {
"seq_lengths": seq_lengths,
"cu_seq": cu_seq,
"x_packed": x_packed,
"x_padded": x_padded,
"max_len": max_len,
"total_tokens": total_tokens,
}
class TestVarlenAttention(NNTestCase):
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_basic_functionality(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
output = attention_block.forward_varlen(
x_packed, cu_seq, shape.max_seq_len, is_causal=False
)
self.assertEqual(output.shape, (total_tokens, shape.embed_dim))
self.assertEqual(output.device, torch.device(device))
self.assertEqual(output.dtype, dtype)
loss = output.sum()
loss.backward()
self.assertIsNotNone(x_packed.grad)
self.assertEqual(x_packed.grad.shape, x_packed.shape)
self.assertEqual(x_packed.grad.dtype, x_packed.dtype)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_registration(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
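# Compile with the eager backend and log dispatched ops via OpLoggingMode to confirm
# the call is routed through the torch_nn_attention::_varlen_attn custom op rather
# than being traced into its constituent aten ops.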
compiled_forward = torch.compile(
attention_block.forward_varlen, backend="eager", fullgraph=True
)
with OpLoggingMode() as mode:
output = compiled_forward(
x_packed, cu_seq, shape.max_seq_len, is_causal=False
)
self.assertEqual(output.shape, (total_tokens, shape.embed_dim))
self.assertEqual(output.device, torch.device(device))
self.assertEqual(output.dtype, dtype)
called_ops = mode.called_ops
custom_op_called = any(
"torch_nn_attention._varlen_attn" in op for op in called_ops
)
self.assertTrue(custom_op_called)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
# @parametrize("dtype", [torch.bfloat16, torch.float16])
# @parametrize("is_causal", [False, True])
@parametrize("dtype", [torch.bfloat16])
@parametrize("is_causal", [False])
def test_varlen_vs_sdpa(self, device, dtype, is_causal):
torch.manual_seed(42)
# shape = VarlenShape(
# batch_size=8, max_seq_len=2048, embed_dim=1024, num_heads=16
# )
shape = VarlenShape(
batch_size=2, max_seq_len=128, embed_dim=32, num_heads=4
)
# shape = VarlenShape(
# batch_size=2, max_seq_len=2048, embed_dim=1024, num_heads=16
# )
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
variable_length_batch_data = create_variable_length_batch(shape, device, dtype)
varlen_output = attention_block.forward_varlen(
variable_length_batch_data["x_packed"],
variable_length_batch_data["cu_seq"],
variable_length_batch_data["max_len"],
is_causal=is_causal,
)
sdpa_output = attention_block.forward_sdpa(
variable_length_batch_data["x_padded"],
variable_length_batch_data["seq_lengths"],
dtype=dtype,
is_causal=is_causal,
)
tolerances = default_tolerances[dtype]
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
varlen_seq = varlen_output[start_idx:end_idx]
sdpa_seq = sdpa_output[i, :seq_len]
print(f"varlen_seq: {varlen_seq}")
print(f"sdpa_seq: {sdpa_seq}")
torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
start_idx = end_idx
varlen_grad_out = torch.ones_like(varlen_output)
sdpa_grad_out = torch.zeros_like(sdpa_output)
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
start_idx = end_idx
varlen_grad = torch.autograd.grad(
outputs=varlen_output,
inputs=variable_length_batch_data["x_packed"],
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
sdpa_grad = torch.autograd.grad(
outputs=sdpa_output,
inputs=variable_length_batch_data["x_padded"],
grad_outputs=sdpa_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
varlen_grad_seq = varlen_grad[start_idx:end_idx]
sdpa_grad_seq = sdpa_grad[i, :seq_len]
print(f"varlen_grad_seq: {varlen_grad_seq}")
print(f"sdpa_grad_seq: {sdpa_grad_seq}")
torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
start_idx = end_idx
device_types = ("cuda",)
instantiate_device_type_tests(TestVarlenAttention, globals(), only_for=device_types)
if __name__ == "__main__":
run_tests()

View File

@@ -15,7 +15,11 @@ from torch.backends.cuda import (
)
-__all__: list[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
+__all__: list[str] = [
+    "SDPBackend",
+    "sdpa_kernel",
+    "WARN_FOR_UNFUSED_KERNELS",
+]
# Note: [SDPA warnings]
# TODO: Consider using this for sdpa regardless of subclasses

View File

@@ -0,0 +1,265 @@
"""
Variable-length attention implementation using Flash Attention.
This module provides a high-level Python interface for variable-length attention
that calls into the optimized Flash Attention kernels.
"""
import logging
from functools import lru_cache
from typing import Union, NamedTuple, Optional
import torch
log = logging.getLogger(__name__)
__all__ = ["varlen_attn", "AuxRequest"]
@lru_cache(maxsize=8)
def _should_use_cudnn(device_index: int) -> bool:
"""Cache device capability check to avoid repeated CUDA calls."""
return True
class AuxRequest(NamedTuple):
lse: bool = False
@torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
def _varlen_attn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool = False,
attn_bias: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Private custom op for variable-length attention using Flash Attention.
This is the internal implementation that calls into the Flash Attention kernels.
Users should use the public varlen_attn function instead.
"""
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
if use_cudnn:
log.info("Using cuDNN backend for varlen_attn")
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
result = torch.ops.aten._cudnn_attention_forward(
query,
key,
value,
None, # attn_bias
cu_seq_q,
cu_seq_k,
max_q,
max_k,
True, # compute_log_sumexp
0.0, # dropout_p hardcoded to 0.0
is_causal,
False, # return_debug_mask
)
# cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
output, softmax_lse, rng_state, philox_offset = result[0], result[1], result[6], result[7]
else:
log.info("Using Flash Attention backend for varlen_attn")
output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
query,
key,
value,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0, # dropout_p hardcoded to 0.0
is_causal,
return_debug_mask=False,
)
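# The Flash path's philox offset is not used here; return an empty placeholder so
# both backends share the same (output, lse, rng_state, philox_offset) signature.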
philox_offset = torch.empty(0, device=query.device)
return output, softmax_lse, rng_state, philox_offset
# @_varlen_attn.register_fake
def _varlen_attn_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
Based on the 3D varlen path from meta__flash_attention_forward:
- query shape: (total, num_heads, head_dim)
- logsumexp shape: (num_heads, total_q)
"""
# Output has same shape as query
output = torch.empty_like(query)
# For varlen path: logsumexp shape is (num_heads, total_q)
total_q = query.size(0)
num_heads = query.size(1)
logsumexp = torch.empty(
(num_heads, total_q), dtype=torch.float, device=query.device
)
return output, logsumexp
def varlen_attn(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool = False,
return_aux: Optional[AuxRequest] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
Compute variable-length attention using Flash Attention.
This function is similar to scaled_dot_product_attention but optimized for
variable-length sequences using cumulative sequence position tensors.
Args:
query (Tensor): Query tensor; shape :math:`(T_q, H, D)`
key (Tensor): Key tensor; shape :math:`(T_k, H, D)`
value (Tensor): Value tensor; shape :math:`(T_k, H, D)`
cu_seq_q (Tensor): Cumulative sequence positions for queries; shape :math:`(N+1,)`
cu_seq_k (Tensor): Cumulative sequence positions for keys/values; shape :math:`(N+1,)`
max_q (int): Maximum query sequence length in the batch.
max_k (int): Maximum key/value sequence length in the batch.
is_causal (bool, optional): If set to True, applies causal masking (default: False).
return_aux (Optional[AuxRequest]): If not None and ``return_aux.lse`` is True, also returns the logsumexp tensor.
Shape legend:
- :math:`N`: Batch size
- :math:`T_q`: Total number of query tokens in the batch (sum of all query sequence lengths)
- :math:`T_k`: Total number of key/value tokens in the batch (sum of all key/value sequence lengths)
- :math:`H`: Number of attention heads
- :math:`D`: Head dimension
Returns:
Tensor: Output tensor from attention computation
If ``return_aux`` is not None and ``return_aux.lse`` is True, returns a tuple of Tensors:
(output, lse), where lse is the logsumexp
Example:
>>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16
>>> head_dim = embed_dim // num_heads
>>> seq_lengths = []
>>> for _ in range(batch_size):
...     length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
...     seq_lengths.append(min(length, max_seq_len))
>>> seq_lengths = torch.tensor(seq_lengths, device="cuda")
>>> total_tokens = seq_lengths.sum().item()
>>> # Create packed query, key, value tensors
>>> query = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
>>> key = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
>>> value = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
>>> # Build cumulative sequence tensor
>>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
>>> cu_seq[1:] = seq_lengths.cumsum(0)
>>> max_len = seq_lengths.max().item()
>>> # Call varlen_attn
>>> output = varlen_attn(query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False)
"""
out, lse, _, _ = torch.ops.torch_nn_attention._varlen_attn(
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
)
if return_aux is not None and return_aux.lse:
return out, lse
return out
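# Autograd support for the custom op: setup_context stashes the forward inputs and
# outputs on ctx, and backward routes to the matching cuDNN / Flash Attention backward
# kernel. Both are wired up via _varlen_attn.register_autograd at the end of this file.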
def setup_context(ctx, inputs, output):
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal, attn_bias = inputs
out, lse, rng_state, philox_offset = output
ctx.query = query
ctx.key = key
ctx.value = value
ctx.cu_seq_q = cu_seq_q
ctx.cu_seq_k = cu_seq_k
ctx.max_q = max_q
ctx.max_k = max_k
ctx.is_causal = is_causal
ctx.attn_bias = attn_bias
ctx.output = out
ctx.lse = lse
ctx.rng_state = rng_state
ctx.philox_offset = philox_offset
def backward(ctx, grad_out, grad_lse, grad_rng, grad_philox_offset):
query = ctx.query
key = ctx.key
value = ctx.value
cu_seq_q = ctx.cu_seq_q
cu_seq_k = ctx.cu_seq_k
max_q = ctx.max_q
max_k = ctx.max_k
is_causal = ctx.is_causal
attn_bias = ctx.attn_bias
out = ctx.output
lse = ctx.lse
rng_state = getattr(ctx, "rng_state", torch.empty(0, device=query.device))
philox_offset = getattr(ctx, "philox_offset", torch.empty(0, device=query.device))
unused = torch.empty(0, device=query.device)
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
if use_cudnn:
log.info("Using cuDNN backend for varlen_attn")
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
rng_state,
philox_offset,
attn_bias,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
)
else:
log.info("Using Flash Attention backend for varlen_attn")
dq, dk, dv = torch.ops.aten._flash_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
rng_state,
unused,
)
return dq, dk, dv, None, None, None, None, None, None
_varlen_attn.register_autograd(backward, setup_context=setup_context)