Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
varlen api (#164502)
**Summary**

Today, the only way to get variable-sequence-length support in PyTorch attention is through nested tensors ([here](https://docs.pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html#nestedtensor-and-dense-tensor-support)). We also want an explicit lower-level API that provides variable-sequence-length support without the padding/masking required by SDPA. This PR builds out `varlen_attn`, the public API that users call for the forward pass, and `_varlen_attn`, the private API that calls into the Flash Attention/cuDNN backend.

**Benchmarking**

To benchmark, we compare runtime and TFLOPs against the current SDPA approach with padding.

Settings:

- 1 H100 machine
- `batch_size=8`, `max_seq_len=2048`, `embed_dim=1024`, `num_heads=16`
- dtype `torch.bfloat16`
- `is_causal=False`
- for variable length, sequence lengths are random multiples of 64 up to `max_seq_len`
- 100 runs

|         | Variable Length API    | SDPA                   |
|---------|------------------------|------------------------|
| Runtime | 0.21750560760498047 ms | 0.43171775817871094 ms |
| TFLOPs  | 231.812                | 320.840                |

The sparsity is 0.453, which matches the roughly 50% speedup we get from varlen (0.2175 ms vs. 0.4317 ms). TFLOPs remain roughly comparable, with SDPA slightly higher due to its extra overhead and its total FLOPs scaling with the padded sequence length.

**Testing**

Run `python test/test_varlen_attention.py` for unit tests that verify basic functionality and confirm a numerical match between varlen outputs and SDPA.

**Next steps**

Next steps from this PR (higher in the stack) include registering the private API `_varlen_attn` as a custom op, implementing backward support, and enabling cuDNN with correct numerics.

(This stack builds on top of #162326)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164502
Approved by: https://github.com/v0i0, https://github.com/drisspg
Committed by: PyTorch MergeBot
Parent: 2b71b62045
Commit: 78f5a1ec60
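Before the diffs, here is a minimal usage sketch of the packed, padding-free layout the PR describes, following the `varlen_attn` signature added in `torch/nn/attention/varlen.py` below. The sequence lengths, head count, and dtype are illustrative assumptions, not values from the PR.

```python
import torch

from torch.nn.attention import varlen_attn

# Two sequences of lengths 128 and 320, packed along the token dimension
# with no padding: total_tokens = 448 (illustrative sizes).
seq_lens = torch.tensor([128, 320], device="cuda")
num_heads, head_dim = 16, 64
total_tokens = int(seq_lens.sum())

q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Cumulative sequence positions [0, 128, 448]: one boundary per sequence.
cu_seq = torch.zeros(seq_lens.numel() + 1, device="cuda", dtype=torch.int32)
cu_seq[1:] = seq_lens.cumsum(0)
max_len = int(seq_lens.max())

# Attention over the packed batch; no padding or attention mask is needed.
out = varlen_attn(q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=True)
print(out.shape)  # same shape as the packed query: (total_tokens, num_heads, head_dim)
```

Entry `i+1` of `cu_seq` marks where sequence `i` ends in the packed token dimension, which is how the kernel recovers per-sequence boundaries without padding.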
@@ -23,6 +23,7 @@ Submodules
     flex_attention
     bias
     experimental
+    varlen

 .. toctree::
    :hidden:
@@ -30,3 +31,4 @@ Submodules
    nn.attention.flex_attention
    nn.attention.bias
    nn.attention.experimental
+   nn.attention.varlen
docs/source/nn.attention.varlen.md (new file, 17 lines)
@@ -0,0 +1,17 @@
```{eval-rst}
.. role:: hidden
    :class: hidden-section
```

# torch.nn.attention.varlen

```{eval-rst}
.. automodule:: torch.nn.attention.varlen
.. currentmodule:: torch.nn.attention.varlen
```
```{eval-rst}
.. autofunction:: varlen_attn
```
```{eval-rst}
.. autoclass:: AuxRequest
```
test/test_varlen_attention.py (new file, 195 lines)
@@ -0,0 +1,195 @@
# Owner(s): ["module: sdpa"]
import unittest
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.attention import varlen_attn
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_nn import NNTestCase
from torch.testing._internal.common_utils import parametrize, run_tests


VarlenShape = namedtuple(
    "VarlenShape", ["batch_size", "max_seq_len", "embed_dim", "num_heads"]
)

default_tolerances = {
    torch.float16: {"atol": 1e-1, "rtol": 1e-1},
    torch.bfloat16: {"atol": 9e-2, "rtol": 5e-2},
    torch.float32: {"atol": 1e-5, "rtol": 1.3e-6},
}


class AttentionBlock(nn.Module):
    def __init__(
        self, embed_dim: int, num_heads: int, device: torch.device, dtype: torch.dtype
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        self.qkv_proj = nn.Linear(
            embed_dim, 3 * embed_dim, bias=False, device=device, dtype=dtype
        )
        self.out_proj = nn.Linear(
            embed_dim, embed_dim, bias=False, device=device, dtype=dtype
        )

    def forward_varlen(
        self,
        x_packed: torch.Tensor,
        cu_seq: torch.Tensor,
        max_len: int,
        is_causal: bool = False,
    ):
        qkv = self.qkv_proj(x_packed)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(-1, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_heads, self.head_dim)
        v = v.view(-1, self.num_heads, self.head_dim)

        attn_out = varlen_attn(
            q, k, v, cu_seq, cu_seq, max_len, max_len, is_causal=is_causal
        )
        attn_out = attn_out.view(-1, self.embed_dim)

        return self.out_proj(attn_out)

    def forward_sdpa(self, x_padded: torch.Tensor, is_causal: bool = False):
        batch_size, seq_len, _ = x_padded.shape

        qkv = self.qkv_proj(x_padded)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
        attn_out = (
            attn_out.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, self.embed_dim)
        )

        return self.out_proj(attn_out)


def create_variable_length_batch(
    shape: VarlenShape, device: torch.device, dtype: torch.dtype
):
    seq_lengths = []
    for _ in range(shape.batch_size):
        length = torch.randint(1, shape.max_seq_len // 64 + 1, (1,)).item() * 64
        seq_lengths.append(min(length, shape.max_seq_len))

    seq_lengths = torch.tensor(seq_lengths, device=device)
    total_tokens = seq_lengths.sum().item()

    x_packed = torch.randn(total_tokens, shape.embed_dim, device=device, dtype=dtype)

    cu_seq = torch.zeros(shape.batch_size + 1, device=device, dtype=torch.int32)
    cu_seq[1:] = seq_lengths.cumsum(0)

    max_len = seq_lengths.max().item()
    x_padded = torch.zeros(
        shape.batch_size, max_len, shape.embed_dim, device=device, dtype=dtype
    )

    start_idx = 0
    for i, seq_len in enumerate(seq_lengths):
        end_idx = start_idx + seq_len
        x_padded[i, :seq_len] = x_packed[start_idx:end_idx]
        start_idx = end_idx

    return {
        "seq_lengths": seq_lengths,
        "cu_seq": cu_seq,
        "x_packed": x_packed,
        "x_padded": x_padded,
        "max_len": max_len,
        "total_tokens": total_tokens,
    }


class TestVarlenAttention(NNTestCase):
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
    )
    @parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_basic_functionality(self, device, dtype):
        torch.manual_seed(42)

        shape = VarlenShape(batch_size=2, max_seq_len=512, embed_dim=1024, num_heads=16)

        attention_block = AttentionBlock(
            shape.embed_dim, shape.num_heads, device, dtype
        )

        total_tokens = shape.batch_size * shape.max_seq_len
        x_packed = torch.randn(
            total_tokens, shape.embed_dim, device=device, dtype=dtype
        )
        cu_seq = torch.tensor(
            [0, shape.max_seq_len, total_tokens], device=device, dtype=torch.int32
        )

        output = attention_block.forward_varlen(
            x_packed, cu_seq, shape.max_seq_len, is_causal=False
        )

        self.assertEqual(output.shape, (total_tokens, shape.embed_dim))
        self.assertEqual(output.device, torch.device(device))
        self.assertEqual(output.dtype, dtype)

    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Flash Attention not supported"
    )
    @parametrize("dtype", [torch.bfloat16, torch.float16])
    @parametrize("is_causal", [False, True])
    def test_varlen_vs_sdpa(self, device, dtype, is_causal):
        torch.manual_seed(42)

        shape = VarlenShape(
            batch_size=8, max_seq_len=2048, embed_dim=1024, num_heads=16
        )

        attention_block = AttentionBlock(
            shape.embed_dim, shape.num_heads, device, dtype
        )

        variable_length_batch_data = create_variable_length_batch(shape, device, dtype)

        varlen_output = attention_block.forward_varlen(
            variable_length_batch_data["x_packed"],
            variable_length_batch_data["cu_seq"],
            variable_length_batch_data["max_len"],
            is_causal=is_causal,
        )
        sdpa_output = attention_block.forward_sdpa(
            variable_length_batch_data["x_padded"], is_causal=is_causal
        )

        tolerances = default_tolerances[dtype]
        start_idx = 0
        for i, seq_len in enumerate(variable_length_batch_data["seq_lengths"]):
            end_idx = start_idx + seq_len

            varlen_seq = varlen_output[start_idx:end_idx]
            sdpa_seq = sdpa_output[i, :seq_len]

            torch.testing.assert_close(varlen_seq, sdpa_seq, **tolerances)
            start_idx = end_idx


device_types = ("cuda",)

instantiate_device_type_tests(TestVarlenAttention, globals(), only_for=device_types)

if __name__ == "__main__":
    run_tests()
@@ -14,8 +14,15 @@ from torch.backends.cuda import (
     SDPAParams,
 )
 
+from .varlen import varlen_attn
+
 
-__all__: list[str] = ["SDPBackend", "sdpa_kernel", "WARN_FOR_UNFUSED_KERNELS"]
+__all__: list[str] = [
+    "SDPBackend",
+    "sdpa_kernel",
+    "WARN_FOR_UNFUSED_KERNELS",
+    "varlen_attn",
+]
 
 # Note: [SDPA warnings]
 # TODO: Consider using this for sdpa regardless of subclasses
torch/nn/attention/varlen.py (new file, 199 lines)
@@ -0,0 +1,199 @@
"""
Variable-length attention implementation using Flash Attention.

This module provides a high-level Python interface for variable-length attention
that calls into the optimized Flash Attention kernels.
"""

import logging
from functools import lru_cache
from typing import NamedTuple, Optional, Union

import torch


log = logging.getLogger(__name__)

__all__ = ["varlen_attn", "AuxRequest"]


@lru_cache(maxsize=8)
def _should_use_cudnn(device_index: int) -> bool:
    """Cache device capability check to avoid repeated CUDA calls."""
    return False


class AuxRequest(NamedTuple):
    """
    Request which auxiliary outputs to compute from varlen_attn.

    Each field is a boolean indicating whether that auxiliary output should be computed.
    """

    lse: bool = False


# import failures when I try to register as custom op
# @torch.library.custom_op("torch_nn_attention::_varlen_attn", mutates_args={})
def _varlen_attn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Private custom op for variable-length attention.

    This is the internal implementation. Users should use the public varlen_attn function instead.
    """
    use_cudnn = query.is_cuda and _should_use_cudnn(query.device.index)

    if use_cudnn:
        log.info("Using cuDNN backend for varlen_attn")
        result = torch.ops.aten._cudnn_attention_forward(
            query,
            key,
            value,
            None,  # attn_bias
            cu_seq_q,
            cu_seq_k,
            max_q,
            max_k,
            True,  # compute_log_sumexp
            0.0,  # dropout_p hardcoded to 0.0
            is_causal,
            False,  # return_debug_mask
        )
        # cuDNN returns: (output, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask)
        output, softmax_lse = result[0], result[1]
    else:
        log.info("Using Flash Attention backend for varlen_attn")
        output, softmax_lse, rng_state, _, _ = torch.ops.aten._flash_attention_forward(
            query,
            key,
            value,
            cu_seq_q,
            cu_seq_k,
            max_q,
            max_k,
            0.0,  # dropout_p hardcoded to 0.0
            is_causal,
            return_debug_mask=False,
        )

    return output, softmax_lse


# @_varlen_attn.register_fake
def _varlen_attn_fake(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Fake implementation for meta tensor computation and tracing.

    Based on the 3D varlen path from meta__flash_attention_forward:
    - query shape: (total, num_heads, head_dim)
    - logsumexp shape: (num_heads, total_q)
    """
    # Output has same shape as query
    output = torch.empty_like(query)

    # For varlen path: logsumexp shape is (num_heads, total_q)
    total_q = query.size(0)
    num_heads = query.size(1)
    logsumexp = torch.empty(
        (num_heads, total_q), dtype=torch.float, device=query.device
    )

    return output, logsumexp


def varlen_attn(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seq_q: torch.Tensor,
    cu_seq_k: torch.Tensor,
    max_q: int,
    max_k: int,
    is_causal: bool = False,
    return_aux: Optional[AuxRequest] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
    """
    Compute variable-length attention using Flash Attention.

    This function is similar to scaled_dot_product_attention but optimized for
    variable-length sequences using cumulative sequence position tensors.

    Args:
        - query (Tensor): Query tensor; shape :math:`(T_q, H, D)`
        - key (Tensor): Key tensor; shape :math:`(T_k, H, D)`
        - value (Tensor): Value tensor; shape :math:`(T_k, H, D)`
        - cu_seq_q (Tensor): Cumulative sequence positions for queries; shape :math:`(N+1,)`
        - cu_seq_k (Tensor): Cumulative sequence positions for keys/values; shape :math:`(N+1,)`
        - max_q (int): Maximum query sequence length in the batch.
        - max_k (int): Maximum key/value sequence length in the batch.
        - is_causal (bool, optional): If set to True, applies causal masking (default: False).
        - return_aux (Optional[AuxRequest]): If not None and ``return_aux.lse`` is True, also returns the logsumexp tensor.

    Shape legend:
        - :math:`N`: Batch size
        - :math:`T_q`: Total number of query tokens in the batch (sum of all query sequence lengths)
        - :math:`T_k`: Total number of key/value tokens in the batch (sum of all key/value sequence lengths)
        - :math:`H`: Number of attention heads
        - :math:`D`: Head dimension

    Returns:
        - Tensor: Output tensor from attention computation
        - If ``return_aux`` is not None and ``return_aux.lse`` is True, returns a tuple of Tensors:
          (output, lse), where lse is the logsumexp

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16
        >>> head_dim = embed_dim // num_heads
        >>> seq_lengths = []
        >>> for _ in range(batch_size):
        ...     length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
        ...     seq_lengths.append(min(length, max_seq_len))
        >>> seq_lengths = torch.tensor(seq_lengths, device="cuda")
        >>> total_tokens = seq_lengths.sum().item()
        >>>
        >>> # Create packed query, key, value tensors
        >>> query = torch.randn(
        ...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
        ... )
        >>> key = torch.randn(
        ...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
        ... )
        >>> value = torch.randn(
        ...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
        ... )
        >>>
        >>> # Build cumulative sequence tensor
        >>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
        >>> cu_seq[1:] = seq_lengths.cumsum(0)
        >>> max_len = seq_lengths.max().item()
        >>>
        >>> # Call varlen_attn
        >>> output = varlen_attn(
        ...     query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
        ... )
    """
    out, lse = _varlen_attn(
        query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal
    )
    if return_aux is not None and return_aux.lse:
        return out, lse
    return out
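Complementing the docstring example above, here is a short, hedged sketch of the `return_aux` path using `AuxRequest` as defined in this file; the tensor sizes and dtype are illustrative assumptions.

```python
import torch

from torch.nn.attention import varlen_attn
from torch.nn.attention.varlen import AuxRequest

# Packed q/k/v for two sequences of lengths 64 and 192 (256 tokens total),
# 16 heads of head_dim 64, matching the (T, H, D) layout in the docstring.
q = torch.randn(256, 16, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
cu_seq = torch.tensor([0, 64, 256], device="cuda", dtype=torch.int32)

# Requesting the logsumexp turns the return value into an (output, lse) tuple.
out, lse = varlen_attn(
    q, k, v, cu_seq, cu_seq, 192, 192,
    is_causal=False,
    return_aux=AuxRequest(lse=True),
)
```

Per the code above, when `return_aux` is `None` (the default) only the output tensor is returned.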