diff --git a/docs/source/nn.attention.rst b/docs/source/nn.attention.rst index 120535d00259..8e7e6b0a762a 100644 --- a/docs/source/nn.attention.rst +++ b/docs/source/nn.attention.rst @@ -23,6 +23,7 @@ Submodules flex_attention bias experimental + varlen .. toctree:: :hidden: @@ -30,3 +31,4 @@ Submodules nn.attention.flex_attention nn.attention.bias nn.attention.experimental + nn.attention.varlen diff --git a/docs/source/nn.attention.varlen.md b/docs/source/nn.attention.varlen.md new file mode 100644 index 000000000000..df91e1d968e6 --- /dev/null +++ b/docs/source/nn.attention.varlen.md @@ -0,0 +1,17 @@ +```{eval-rst} +.. role:: hidden + :class: hidden-section +``` + +# torch.nn.attention.varlen + +```{eval-rst} +.. automodule:: torch.nn.attention.varlen +.. currentmodule:: torch.nn.attention.varlen +``` +```{eval-rst} +.. autofunction:: varlen_attn +``` +```{eval-rst} +.. autoclass:: AuxRequest +``` diff --git a/test/test_varlen_attention.py b/test/test_varlen_attention.py new file mode 100644 index 000000000000..f249adf21a52 --- /dev/null +++ b/test/test_varlen_attention.py @@ -0,0 +1,195 @@ +# Owner(s): ["module: sdpa"] +import unittest +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.attention import varlen_attn +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION +from torch.testing._internal.common_device_type import instantiate_device_type_tests +from torch.testing._internal.common_nn import NNTestCase +from torch.testing._internal.common_utils import parametrize, run_tests + + +VarlenShape = namedtuple( + "VarlenShape", ["batch_size", "max_seq_len", "embed_dim", "num_heads"] +) + +default_tolerances = { + torch.float16: {"atol": 1e-1, "rtol": 1e-1}, + torch.bfloat16: {"atol": 9e-2, "rtol": 5e-2}, + torch.float32: {"atol": 1e-5, "rtol": 1.3e-6}, +} + + +class AttentionBlock(nn.Module): + def __init__( + self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + self.qkv_proj = nn.Linear( + embed_dim, 3 * embed_dim, bias=False, device=device, dtype=dtype + ) + self.out_proj = nn.Linear( + embed_dim, embed_dim, bias=False, device=device, dtype=dtype + ) + + def forward_varlen( + self, + x_packed: torch.Tensor, + cu_seq: torch.Tensor, + max_len: int, + is_causal: bool = False, + ): + qkv = self.qkv_proj(x_packed) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(-1, self.num_heads, self.head_dim) + k = k.view(-1, self.num_heads, self.head_dim) + v = v.view(-1, self.num_heads, self.head_dim) + + attn_out = varlen_attn( + q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal + ) + attn_out = attn_out.view(-1, self.embed_dim) + + return self.out_proj(attn_out) + + def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False): + batch_size, seq_len, _ = x_padded.shape + + qkv = self.qkv_proj(x_padded) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) + + attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal) + attn_out = ( + attn_out.transpose(1, 2) + .contiguous() + .view(batch_size, seq_len, self.embed_dim) + ) + + return self.out_proj(attn_out) + + +def create_variable_length_batch( 
+ shape: VarlenShape, device: torch.device, dtype: torch.dtype +): + seq_lengths = [] + for _ in range(shape.batch_size): + length = torch.randint(1, shape.max_seq_len // 64 + 1, (1,)).item() * 64 + seq_lengths.append(min(length, shape.max_seq_len)) + + seq_lengths = torch.tensor(seq_lengths, device=device) + total_tokens = seq_lengths.sum().item() + + x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype) + + cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32) + cu_seq[1:] = seq_lengths.cumsum(0) + + max_len = seq_lengths.max().item() + x_padded = torch.zeros( + shape.batch_size, max_len, shape.embed_dim, device=device, dtype=dtype + ) + + start_idx = 0 + for i, seq_len in enumerate(seq_lengths): + end_idx = start_idx + seq_len + x_padded[i, :seq_len] = x_packed[start_idx:end_idx] + start_idx = end_idx + + return { + "seq_lengths": seq_lengths, + "cu_seq": cu_seq, + "x_packed": x_packed, + "x_padded": x_padded, + "max_len": max_len, + "total_tokens": total_tokens, + } + + +class TestVarlenAttention(NNTestCase): + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" + ) + @parametrize("dtype", [torch.bfloat16, torch.float16]) + def test_basic_functionality(self, device, dtype): + torch.manual_seed(42) + + shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16) + + attention_block = AttentionBlock( + shape.embed_dim, shape.num_heads, device, dtype + ) + + total_tokens = shape.batch_size * shape.max_seq_len + x_packed = torch.randn( + total_tokens, shape.embed_dim, device=device, dtype=dtype + ) + cu_seq = torch.tensor( + [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32 + ) + + output = attention_block.forward_varlen( + x_packed, cu_seq, shape.max_seq_len, is_causal=False + ) + + self.assertEqual(output.shape, (total_tokens, shape.embed_dim)) + self.assertEqual(output.device, torch.device(device)) + self.assertEqual(output.dtype, dtype) + + @unittest.skipIf( + not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported" + ) + @parametrize("dtype", [torch.bfloat16, torch.float16]) + @parametrize("is_causal", [False, True]) + def test_varlen_vs_sdpa(self, device, dtype, is_causal): + torch.manual_seed(42) + + shape = VarlenShape( + batch_size=8, max_seq_len=2048, embed_dim=1024, num_heads=16 + ) + + attention_block = AttentionBlock( + shape.embed_dim, shape.num_heads, device, dtype + ) + + variable_length_batch_data = create_variable_length_batch(shape, device, dtype) + + varlen_output = attention_block.forward_varlen( + variable_length_batch_data["x_packed"], + variable_length_batch_data["cu_seq"], + variable_length_batch_data["max_len"], + is_causal=is_causal, + ) + sdpa_output = attention_block.forward_sdpa( + variable_length_batch_data["x_padded"], is_causal=is_causal + ) + + tolerances = default_tolerances[dtype] + start_idx = 0 + for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]): + end_idx = start_idx + seq_len + + varlen_seq = varlen_output[start_idx:end_idx] + sdpa_seq = sdpa_output[i, :seq_len] + + torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances) + start_idx = end_idx + + +device_types = ("cuda",) + +instantiate_device_type_tests(TestVarlenAttention, globals(), only_for=device_types) + +if __name__ == "__main__": + run_tests() diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index efdd7daa0d2a..e1adc664e20f 100644 --- a/torch/nn/attention/__init__.py +++ 
b/torch/nn/attention/__init__.py @@ -14,8 +14,15 @@ from torch.backends.cuda import ( SDPAParams, ) +from .varlen import varlen_attn -__all__: list[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"] + +__all__: list[str] = [ + "SDPBackend", + "sdpa_kernel", + "WARN_FOR_UNFUSED_KERNELS", + "varlen_attn", +] # Note: [SDPA warnings] # TODO: Consider using this for sdpa regardless of subclasses diff --git a/torch/nn/attention/varlen.py b/torch/nn/attention/varlen.py new file mode 100644 index 000000000000..7234dd5e7912 --- /dev/null +++ b/torch/nn/attention/varlen.py @@ -0,0 +1,199 @@ +""" +Variable-length attention implementation using Flash Attention. + +This module provides a high-level Python interface for variable-length attention +that calls into the optimized Flash Attention kernels. +""" + +import logging +from functools import lru_cache +from typing import NamedTuple, Optional, Union + +import torch + + +log = logging.getLogger(__name__) + +__all__ = ["varlen_attn", "AuxRequest"] + + +@lru_cache(maxsize=8) +def _should_use_cudnn(device_index: int) -> bool: + """Cache device capability check to avoid repeated CUDA calls.""" + return False + + +class AuxRequest(NamedTuple): + """ + Request which auxiliary outputs to compute from varlen_attn. + + Each field is a boolean indicating whether that auxiliary output should be computed. + """ + + lse: bool = False + + +# import failures when I try to register as custom op +# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={}) +def _varlen_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Private custom op for variable-length attention. + + This is the internal implementation. Users should use the public varlen_attn function instead. + """ + + use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index) + + if use_cudnn: + log.info("Using cuDNN backend for varlen_attn") + result = torch.ops.aten._cudnn_attention_forward( + query, + key, + value, + None, # attn_bias + cu_seq_q, + cu_seq_k, + max_q, + max_k, + True, # compute_log_sumexp + 0.0, # dropout_p hardcoded to 0.0 + is_causal, + False, # return_debug_mask + ) + # cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask) + output, softmax_lse = result[0], result[1] + else: + log.info("Using Flash Attention backend for varlen_attn") + output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward( + query, + key, + value, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + 0.0, # dropout_p hardcoded to 0.0 + is_causal, + return_debug_mask=False, + ) + + return output, softmax_lse + + +# @_varlen_attn.register_fake +def _varlen_attn_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Fake implementation for meta tensor computation and tracing. 
+ + Based on the 3D varlen path from meta__flash_attention_forward: + - query shape: (total, num_heads, head_dim) + - logsumexp shape: (num_heads, total_q) + """ + # Output has same shape as query + output = torch.empty_like(query) + + # For varlen path: logsumexp shape is (num_heads, total_q) + total_q = query.size(0) + num_heads = query.size(1) + logsumexp = torch.empty( + (num_heads, total_q), dtype=torch.float, device=query.device + ) + + return output, logsumexp + + +def varlen_attn( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + cu_seq_q: torch.Tensor, + cu_seq_k: torch.Tensor, + max_q: int, + max_k: int, + is_causal: bool = False, + return_aux: Optional[AuxRequest] = None, +) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + Compute variable-length attention using Flash Attention. + This function is similar to scaled_dot_product_attention but optimized for + variable-length sequences using cumulative sequence position tensors. + Args: + - query (Tensor): Query tensor; shape :math:`(T_q, H, D)` + - key (Tensor): Key tensor; shape :math:`(T_k, H, D)` + - value (Tensor): Value tensor; shape :math:`(T_k, H, D)` + - cu_seq_q (Tensor): Cumulative sequence positions for queries; shape :math:`(N+1,)` + - cu_seq_k (Tensor): Cumulative sequence positions for keys/values; shape :math:`(N+1,)` + - max_q (int): Maximum query sequence length in the batch. + - max_k (int): Maximum key/value sequence length in the batch. + - is_causal (bool, optional): If set to True, applies causal masking (default: False). + - return_aux (Optional[AuxRequest]): If not None and ``return_aux.lse`` is True, also returns the logsumexp tensor. + + Shape legend: + - :math:`N`: Batch size + - :math:`T_q`: Total number of query tokens in the batch (sum of all query sequence lengths) + - :math:`T_k`: Total number of key/value tokens in the batch (sum of all key/value sequence lengths) + - :math:`H`: Number of attention heads + - :math:`D`: Head dimension + + Returns: + - Tensor: Output tensor from attention computation + - If ``return_aux`` is not None and ``return_aux.lse`` is True, returns a tuple of Tensors: + (output, lse), where lse is the logsumexp + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA) + >>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16 + >>> head_dim = embed_dim // num_heads + >>> seq_lengths = [] + >>> for _ in range(batch_size): + ... length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64 + ... seq_lengths.append(min(length, max_seq_len)) + >>> seq_lengths = torch.tensor(seq_lengths, device="cuda") + >>> total_tokens = seq_lengths.sum().item() + >>> + >>> # Create packed query, key, value tensors + >>> query = torch.randn( + ... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda" + ... ) + >>> key = torch.randn( + ... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda" + ... ) + >>> value = torch.randn( + ... total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda" + ... ) + >>> + >>> # Build cumulative sequence tensor + >>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + >>> cu_seq[1:] = seq_lengths.cumsum(0) + >>> max_len = seq_lengths.max().item() + >>> + >>> # Call varlen_attn + >>> output = varlen_attn( + ... query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False + ... 
) + """ + out, lse = _varlen_attn( + query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal + ) + if return_aux is not None and return_aux.lse: + return out, lse + return out