Compare commits

...

2 Commits

Author SHA1 Message Date
45715eb46e debugging cudnn numerics
ghstack-source-id: 460fd38569b797bdd607f6672aa16f35177aa5c8
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164950
2025-10-29 14:10:53 -07:00
22c7937326 bwd pass
ghstack-source-id: 563ff6899659ecced546e3723410732f5fc2878f
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164504
2025-10-29 14:10:52 -07:00
3 changed files with 432 additions and 41 deletions
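The two commits add an autograd backward pass for the varlen attention custom op and rework the cuDNN path it dispatches to. For orientation, a minimal sketch (not part of the patch) of the call pattern the new tests exercise, assuming a CUDA device with Flash Attention support; shapes and sequence lengths are illustrative:

import torch
from torch.nn.attention.varlen import varlen_attn

device, dtype = "cuda", torch.float16
num_heads, head_dim = 4, 16

# Two sequences of lengths 3 and 5 packed along the token dimension.
seq_lengths = torch.tensor([3, 5], device=device)
total_tokens = int(seq_lengths.sum())
q = torch.randn(total_tokens, num_heads, head_dim, device=device, dtype=dtype, requires_grad=True)
k = torch.randn_like(q, requires_grad=True)
v = torch.randn_like(q, requires_grad=True)

# Cumulative sequence lengths [0, 3, 8], int32 as in the tests below.
cu_seq = torch.zeros(seq_lengths.numel() + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)
max_len = int(seq_lengths.max())

out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=False)
# The backward registration added here makes the packed inputs differentiable.
dq, dk, dv = torch.autograd.grad(out, (q, k, v), torch.ones_like(out))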

View File

@@ -5,22 +5,29 @@ from collections import namedtuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import varlen_attn
from torch.nn.attention.varlen import varlen_attn
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import parametrize, run_tests
from torch.testing._internal.common_utils import parametrize, run_tests, skipIfRocm
from torch.utils._python_dispatch import TorchDispatchMode
VarlenShape = namedtuple(
"VarlenShape", ["batch_size", "max_seq_len", "embed_dim", "num_heads"]
)
default_tolerances = {
torch.float16: {"atol": 1e-1, "rtol": 1e-1},
torch.bfloat16: {"atol": 9e-2, "rtol": 5e-2},
torch.float32: {"atol": 1e-5, "rtol": 1.3e-6},
}
class OpLoggingMode(TorchDispatchMode):
"""Logging mode that captures all dispatched operations"""
def __init__(self):
self.called_ops = []
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
op_name = str(func)
self.called_ops.append(op_name)
return func(*args, **(kwargs or {}))
class AttentionBlock(nn.Module):
@@ -39,12 +46,9 @@ class AttentionBlock(nn.Module):
embed_dim, embed_dim, bias=False, device=device, dtype=dtype
)
def forward_varlen(
def get_varlen_qkv(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
qkv = self.qkv_proj(x_packed)
q, k, v = qkv.chunk(3, dim=-1)
@@ -53,24 +57,56 @@ class AttentionBlock(nn.Module):
k = k.view(-1, self.num_heads, self.head_dim)
v = v.view(-1, self.num_heads, self.head_dim)
attn_out = varlen_attn(
q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
)
return q, k, v
def forward_varlen(
self,
x_packed: torch.Tensor,
cu_seq: torch.Tensor,
max_len: int,
is_causal: bool = False,
):
q, k, v = self.get_varlen_qkv(x_packed)
attn_out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal)
attn_out = attn_out.view(-1, self.embed_dim)
return self.out_proj(attn_out)
def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
def forward_sdpa(
self,
x_padded: torch.Tensor,
seq_lengths: torch.Tensor,
is_causal: bool = False,
):
batch_size, seq_len, _ = x_padded.shape
qkv = self.qkv_proj(x_padded)
q, k, v = qkv.chunk(3, dim=-1)
mask = (
torch.arange(seq_len, device=x_padded.device)[None, :]
< seq_lengths[:, None]
)
attn_mask = mask[:, None, None, :].expand(
batch_size, self.num_heads, seq_len, seq_len
)
q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
if is_causal:
causal_mask = torch.triu(
torch.ones(seq_len, seq_len, device=x_padded.device, dtype=torch.bool),
diagonal=1,
)
combined_mask = causal_mask[None, None, :, :] | ~attn_mask
attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=~combined_mask)
else:
attn_out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
attn_out = (
attn_out.transpose(1, 2)
.contiguous()
@@ -91,7 +127,9 @@ def create_variable_length_batch(
seq_lengths = torch.tensor(seq_lengths, device=device)
total_tokens = seq_lengths.sum().item()
x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype, requires_grad=True
)
cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)
@@ -106,6 +144,7 @@ def create_variable_length_batch(
end_idx = start_idx + seq_len
x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
start_idx = end_idx
x_padded = x_padded.clone().detach().requires_grad_()
return {
"seq_lengths": seq_lengths,
@@ -118,6 +157,7 @@ def create_variable_length_batch(
class TestVarlenAttention(NNTestCase):
@skipIfRocm(msg="ROCM does not support variable length attention")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@@ -133,7 +173,11 @@ class TestVarlenAttention(NNTestCase):
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens, shape.embed_dim, device=device, dtype=dtype
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
@@ -147,6 +191,131 @@ class TestVarlenAttention(NNTestCase):
self.assertEqual(output.device, torch.device(device))
self.assertEqual(output.dtype, dtype)
varlen_grad_out = torch.ones_like(output)
varlen_grad = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
self.assertIsNotNone(varlen_grad)
self.assertEqual(varlen_grad.shape, x_packed.shape)
self.assertEqual(varlen_grad.dtype, x_packed.dtype)
@skipIfRocm(msg="ROCM does not support variable length attention")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_compliance(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
q, k, v = attention_block.get_varlen_qkv(x_packed)
torch.library.opcheck(
torch.ops.torch_attn._varlen_attn,
(q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False),
)
out, lse, rng_state = torch.ops.torch_attn._varlen_attn(
q, k, v, cu_seq, cu_seq, shape.max_seq_len, shape.max_seq_len, False
)
grad_out = torch.randn_like(out)
# we don't support double backward
# skipping test_autograd_registration, test_aot_dispatch_dynamic, test_aot_dispatch_static
torch.library.opcheck(
torch.ops.torch_attn._varlen_attn_backward,
(
grad_out,
q,
k,
v,
out,
lse,
cu_seq,
cu_seq,
shape.max_seq_len,
shape.max_seq_len,
False,
rng_state,
),
test_utils=["test_schema", "test_faketensor"],
)
@skipIfRocm(msg="ROCM does not support variable length attention")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@parametrize("dtype", [torch.bfloat16, torch.float16])
def test_custom_op_registration(self, device, dtype):
torch.manual_seed(42)
shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
total_tokens = shape.batch_size * shape.max_seq_len
x_packed = torch.randn(
total_tokens,
shape.embed_dim,
device=device,
dtype=dtype,
requires_grad=True,
)
cu_seq = torch.tensor(
[0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
)
compiled_forward = torch.compile(
attention_block.forward_varlen, backend="eager", fullgraph=True
)
with OpLoggingMode() as mode:
output = compiled_forward(
x_packed, cu_seq, shape.max_seq_len, is_causal=False
)
varlen_grad_out = torch.ones_like(output)
_ = torch.autograd.grad(
outputs=output,
inputs=x_packed,
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
called_ops = mode.called_ops
custom_ops_called = any(
"torch_attn._varlen_attn" in op for op in called_ops
) and any("torch_attn._varlen_attn_backward" in op for op in called_ops)
assert custom_ops_called
@skipIfRocm(msg="ROCM does not support variable length attention")
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
)
@@ -156,14 +325,21 @@ class TestVarlenAttention(NNTestCase):
torch.manual_seed(42)
shape = VarlenShape(
batch_size=8, max_seq_len=2048, embed_dim=1024, num_heads=16
batch_size=2, max_seq_len=128, embed_dim=32, num_heads=4
)
attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, dtype
)
golden_attention_block = AttentionBlock(
shape.embed_dim, shape.num_heads, device, torch.float64
)
variable_length_batch_data = create_variable_length_batch(shape, device, dtype)
golden_variable_length_batch_data = create_variable_length_batch(
shape, device, torch.float64
)
varlen_output = attention_block.forward_varlen(
variable_length_batch_data["x_packed"],
@@ -172,18 +348,89 @@ class TestVarlenAttention(NNTestCase):
is_causal=is_causal,
)
sdpa_output = attention_block.forward_sdpa(
variable_length_batch_data["x_padded"], is_causal=is_causal
variable_length_batch_data["x_padded"],
variable_length_batch_data["seq_lengths"],
is_causal=is_causal,
)
golden_sdpa_output = golden_attention_block.forward_sdpa(
golden_variable_length_batch_data["x_padded"],
golden_variable_length_batch_data["seq_lengths"],
is_causal=is_causal,
)
tolerances = default_tolerances[dtype]
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
varlen_seq = varlen_output[start_idx:end_idx]
sdpa_seq = sdpa_output[i, :seq_len]
golden_sdpa_seq = golden_sdpa_output[i, :seq_len]
fwd_atol = (
2 * (golden_sdpa_seq + 0.3 - 0.3 - golden_sdpa_seq).abs().max().item()
)
varlen_error = (varlen_seq - golden_sdpa_seq).abs().max().item()
sdpa_error = (sdpa_seq - golden_sdpa_seq).abs().max().item()
assert varlen_error <= sdpa_error + fwd_atol
start_idx = end_idx
varlen_grad_out = torch.ones_like(varlen_output)
sdpa_grad_out = torch.ones_like(sdpa_output)
golden_sdpa_grad_out = torch.ones_like(golden_sdpa_output)
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
sdpa_grad_out[i, :seq_len] = varlen_grad_out[start_idx:end_idx]
start_idx = end_idx
varlen_grad = torch.autograd.grad(
outputs=varlen_output,
inputs=variable_length_batch_data["x_packed"],
grad_outputs=varlen_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
sdpa_grad = torch.autograd.grad(
outputs=sdpa_output,
inputs=variable_length_batch_data["x_padded"],
grad_outputs=sdpa_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
golden_sdpa_grad = torch.autograd.grad(
outputs=golden_sdpa_output,
inputs=golden_variable_length_batch_data["x_padded"],
grad_outputs=golden_sdpa_grad_out,
retain_graph=True,
create_graph=False,
allow_unused=False,
)[0]
start_idx = 0
for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
end_idx = start_idx + seq_len
varlen_grad_seq = varlen_grad[start_idx:end_idx]
sdpa_grad_seq = sdpa_grad[i, :seq_len]
golden_sdpa_seq = golden_sdpa_grad[i, :seq_len]
fwd_atol = (
2 * (golden_sdpa_seq + 0.3 - 0.3 - golden_sdpa_seq).abs().max().item()
)
varlen_error = (varlen_grad_seq - golden_sdpa_seq).abs().max().item()
sdpa_error = (sdpa_grad_seq - golden_sdpa_seq).abs().max().item()
assert varlen_error <= sdpa_error + fwd_atol
torch.testing.assert_close(varlen_grad_seq, sdpa_grad_seq, **tolerances)
start_idx = end_idx
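The numerics check above follows the Flash Attention test convention: outputs and gradients are compared against a float64 golden reference, and the varlen kernel is only required to stay about as close to that reference as same-dtype SDPA does, plus an absolute floor derived from the reference's own rounding noise. Stripped of the test scaffolding, the bound reduces to the sketch below (names are illustrative, not from the patch):

import torch

def assert_not_worse_than_baseline(test_out, baseline_out, golden_out):
    # Rounding-noise floor of the float64 golden reference itself.
    atol = 2 * (golden_out + 0.3 - 0.3 - golden_out).abs().max().item()
    # Maximum deviation of each implementation from the golden reference.
    test_err = (test_out - golden_out).abs().max().item()
    baseline_err = (baseline_out - golden_out).abs().max().item()
    # The kernel under test may not drift further from the reference
    # than the baseline does, beyond the noise floor.
    assert test_err <= baseline_err + atol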

View File

@@ -14,14 +14,11 @@ from torch.backends.cuda import (
SDPAParams,
)
from .varlen import varlen_attn
__all__: list[str] = [
"SDPBackend",
"sdpa_kernel",
"WARN_FOR_UNFUSED_KERNELS",
"varlen_attn",
]
# Note: [SDPA warnings]

View File

@@ -7,7 +7,7 @@ that calls into the optimized Flash Attention kernels.
import logging
from functools import lru_cache
from typing import NamedTuple, Optional, Union
from typing import Any, NamedTuple, Optional, Union
import torch
@@ -20,7 +20,7 @@ __all__ = ["varlen_attn", "AuxRequest"]
@lru_cache(maxsize=8)
def _should_use_cudnn(device_index: int) -> bool:
"""Cache device capability check to avoid repeated CUDA calls."""
return False
return True
class AuxRequest(NamedTuple):
@@ -33,8 +33,7 @@ class AuxRequest(NamedTuple):
lse: bool = False
# import failures when I try to register as custom op
# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
@torch.library.custom_op("torch_attn::_varlen_attn", mutates_args={})
def _varlen_attn(
query: torch.Tensor,
key: torch.Tensor,
@@ -44,7 +43,7 @@ def _varlen_attn(
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Private custom op for variable-length attention.
@@ -52,9 +51,9 @@ def _varlen_attn(
"""
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
if use_cudnn:
log.info("Using cuDNN backend for varlen_attn")
result = torch.ops.aten._cudnn_attention_forward(
query,
key,
@@ -70,7 +69,7 @@ def _varlen_attn(
False, # return_debug_mask
)
# cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
output, softmax_lse = result[0], result[1]
output, softmax_lse, rng_state, philox_offset = result[0], result[1], result[6], result[7]
else:
log.info("Using Flash Attention backend for varlen_attn")
output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
@@ -85,11 +84,16 @@ def _varlen_attn(
is_causal,
return_debug_mask=False,
)
philox_offset = torch.zeros((), dtype=torch.int64, device=query.device)
return output, softmax_lse
rng_state_ = torch.zeros(
(2,), dtype=torch.uint64, device=query.device
) # hardcoded since dropout is hardcoded to 0
return output, softmax_lse, rng_state_, philox_offset
# @_varlen_attn.register_fake
@_varlen_attn.register_fake
def _varlen_attn_fake(
query: torch.Tensor,
key: torch.Tensor,
@@ -99,7 +103,7 @@ def _varlen_attn_fake(
max_q: int,
max_k: int,
is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
@@ -110,14 +114,24 @@ def _varlen_attn_fake(
# Output has same shape as query
output = torch.empty_like(query)
# For varlen path: logsumexp shape is (num_heads, total_q)
# For varlen path with cuDNN: logsumexp shape is (total_q, num_heads, 1)
total_q = query.size(0)
num_heads = query.size(1)
logsumexp = torch.empty(
(num_heads, total_q), dtype=torch.float, device=query.device
)
return output, logsumexp
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
if use_cudnn:
logsumexp = torch.empty(
(total_q, num_heads, 1), dtype=torch.float, device=query.device
)
else:
logsumexp = torch.empty(
(num_heads, total_q), dtype=torch.float, device=query.device
)
rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device)
philox_offset = torch.zeros((), dtype=torch.int64, device=query.device)
return output, logsumexp, rng_state, philox_offset
def varlen_attn(
@@ -191,9 +205,142 @@ def varlen_attn(
... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
... )
"""
out, lse = _varlen_attn(
out, lse, _, _ = torch.ops.torch_attn._varlen_attn(
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
)
if return_aux is not None and return_aux.lse:
return out, lse
return out
def _setup_context(ctx: Any, inputs: tuple[Any, ...], output: Any) -> None:
query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal = inputs
out, lse, rng_state, philox_offset = output
ctx.save_for_backward(query, key, value, cu_seq_q, cu_seq_k, out, lse, rng_state, philox_offset)
ctx.max_q = max_q
ctx.max_k = max_k
ctx.is_causal = is_causal
@torch.library.custom_op("torch_attn::_varlen_attn_backward", mutates_args={})
def _varlen_attn_backward(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool,
rng_state: torch.Tensor,
philox_offset: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
unused = torch.empty(0, device=query.device)
use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)
if use_cudnn:
log.info("Using cuDNN backend for varlen_attn")
head_dim = query.size(-1)
scale = 1.0 / (head_dim ** 0.5)
dq, dk, dv = torch.ops.aten._cudnn_attention_backward(
grad_out=grad_out,
query=query,
key=key,
value=value,
out=out,
logsumexp=lse,
philox_seed=rng_state,
philox_offset=philox_offset,
attn_bias=None,
cum_seq_q=cu_seq_q,
cum_seq_k=cu_seq_k,
max_q=max_q,
max_k=max_k,
dropout_p=0.0,
is_causal=is_causal,
# passing in scale doesn't change the value of the gradients
# scale=scale
)
else:
log.info("Using Flash Attention backend for varlen_attn")
dq, dk, dv = torch.ops.aten._flash_attention_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
0.0,
is_causal,
rng_state,
unused,
)
return dq, dk, dv
@_varlen_attn_backward.register_fake
def _varlen_attn_backward_fake(
grad_out: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
out: torch.Tensor,
lse: torch.Tensor,
cu_seq_q: torch.Tensor,
cu_seq_k: torch.Tensor,
max_q: int,
max_k: int,
is_causal: bool,
rng_state: torch.Tensor,
philox_offset: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fake implementation for meta tensor computation and tracing.
"""
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
return grad_query, grad_key, grad_value
def _backward(
ctx: Any, grad_out: torch.Tensor, grad_lse: torch.Tensor, grad_rng: torch.Tensor, grad_philox_offset: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:
query, key, value, cu_seq_q, cu_seq_k, out, lse, rng_state, philox_offset = ctx.saved_tensors
max_q = ctx.max_q
max_k = ctx.max_k
is_causal = ctx.is_causal
dq, dk, dv = torch.ops.torch_attn._varlen_attn_backward(
grad_out,
query,
key,
value,
out,
lse,
cu_seq_q,
cu_seq_k,
max_q,
max_k,
is_causal,
rng_state,
philox_offset
)
return dq, dk, dv, None, None, None, None, None, None
_varlen_attn.register_autograd(_backward, setup_context=_setup_context)
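With the forward op, its fake, the backward op, and register_autograd in place, the private op is traceable without launching any kernels: under FakeTensorMode (and therefore inside torch.compile), shapes and dtypes come from _varlen_attn_fake rather than the flash or cuDNN paths. A quick sanity-check sketch, assuming a CUDA build of PyTorch and the op names registered in this file; tensor contents are irrelevant under fake mode:

import torch
import torch.nn.attention.varlen  # registers torch_attn::_varlen_attn and its fake
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    q = torch.randn(8, 4, 16, device="cuda", dtype=torch.float16)
    k = torch.randn(8, 4, 16, device="cuda", dtype=torch.float16)
    v = torch.randn(8, 4, 16, device="cuda", dtype=torch.float16)
    cu_seq = torch.zeros(3, device="cuda", dtype=torch.int32)

    # Dispatches to _varlen_attn_fake: only metadata is computed, no kernel runs.
    out, lse, rng_state, philox_offset = torch.ops.torch_attn._varlen_attn(
        q, k, v, cu_seq, cu_seq, 5, 5, False
    )
    print(out.shape, lse.shape, rng_state.shape)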