Add attention sink in attention backends (#22320)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>

Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Author: Woosuk Kwon
Date: 2025-08-05 22:37:21 -07:00 (committed by GitHub)
Commit: 6e20924350 (parent: dd16bdc798)
7 changed files with 176 additions and 45 deletions


@@ -28,6 +28,7 @@ def kernel_paged_attention_2d(
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
@@ -95,7 +96,17 @@ def kernel_paged_attention_2d(
block_table_offset = seq_idx * block_table_stride
M = tl.full([num_queries_per_kv_padded], float("-inf"), dtype=tl.float32)
if sink_ptr is None:
M = tl.full([num_queries_per_kv_padded],
float("-inf"),
dtype=tl.float32)
else:
M = tl.load(
sink_ptr + query_head_idx,
mask=head_mask,
other=float("-inf"),
).to(dtype=tl.float32)
L = tl.full([num_queries_per_kv_padded], 1.0, dtype=tl.float32)
acc = tl.zeros([num_queries_per_kv_padded, HEAD_SIZE_PADDED],
dtype=tl.float32)
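
The hunk above is the core of the sink mechanism: rather than appending a real sink token to the KV cache, the per-head sink logit seeds the running max M while L stays at 1.0, so the sink enlarges the softmax denominator without ever contributing a value vector. A minimal PyTorch sketch of that equivalence (not part of the diff; single head and unscaled logits assumed):

import torch

def reference(logits, v, sink):
    # Softmax over [sink, k_0 ... k_n]; the sink's probability mass is simply
    # dropped from the weighted sum, like a key whose value vector is zero.
    full = torch.softmax(torch.cat([sink.view(1), logits]), dim=0)
    return full[1:] @ v

def streaming(logits, v, sink):
    # Mirrors the kernel: M starts at the sink logit, L at 1.0, acc at 0.
    m, l, acc = sink.clone(), torch.tensor(1.0), torch.zeros(v.shape[1])
    for logit, val in zip(logits, v):
        m_new = torch.maximum(m, logit)
        alpha = torch.exp(m - m_new)   # rescale previously accumulated state
        p = torch.exp(logit - m_new)
        l = l * alpha + p
        acc = acc * alpha + p * val
        m = m_new
    return acc / l

logits, v, sink = torch.randn(8), torch.randn(8, 4), torch.tensor(0.7)
assert torch.allclose(reference(logits, v, sink), streaming(logits, v, sink), atol=1e-6)
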
@@ -223,6 +234,8 @@ def chunked_prefill_paged_decode(
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
# Optional tensor for sinks
sinks=None,
):
if sm_scale is None:
@@ -253,6 +266,7 @@ def chunked_prefill_paged_decode(
sliding_window=sliding_window,
sm_scale=sm_scale,
skip_decode=True,
sinks=sinks,
)
block_size = value_cache.shape[3]
@@ -281,11 +295,17 @@ def chunked_prefill_paged_decode(
num_queries_per_kv_padded = max(triton.next_power_of_2(num_queries_per_kv),
16)
use_custom = use_rocm_custom_paged_attention(query.dtype, head_size,
block_size,
num_queries_per_kv,
max_seq_len, sliding_window,
kv_cache_dtype, alibi_slopes)
use_custom = use_rocm_custom_paged_attention(
query.dtype,
head_size,
block_size,
num_queries_per_kv,
max_seq_len,
sliding_window,
kv_cache_dtype,
alibi_slopes,
sinks,
)
if use_custom:
_PARTITION_SIZE_ROCM = 256
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
@@ -334,6 +354,7 @@ def chunked_prefill_paged_decode(
query_ptr=query,
key_cache_ptr=key_cache,
value_cache_ptr=value_cache,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seq_lens,
alibi_slopes_ptr=alibi_slopes,


@@ -38,6 +38,7 @@ def _fwd_kernel(Q,
V,
K_cache,
V_cache,
sink_ptr,
B_Loc,
sm_scale,
k_scale,
@@ -126,7 +127,15 @@ def _fwd_kernel(Q,
other=0.0) # [M,D]
# initialize pointer to m and l
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
if sink_ptr is None:
m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
m_i = tl.load(
sink_ptr + tl.full([BLOCK_M], cur_head, dtype=tl.int64),
mask=(offs_m < cur_batch_query_len),
other=float("-inf"),
).to(dtype=tl.float32)
l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL_PADDED], dtype=tl.float32) # [M,D]
@@ -732,7 +741,8 @@ def context_attention_fwd(q,
alibi_slopes=None,
sliding_window=None,
sm_scale=None,
skip_decode=False):
skip_decode=False,
sinks=None):
q_dtype_is_f32 = q.dtype is torch.float32
@@ -781,6 +791,7 @@ def context_attention_fwd(q,
sliding_window = 0
if alibi_slopes is not None:
assert sinks is None, "Sinks arg is not supported with alibi"
# need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory
# if q.dtype is torch.float32:
@@ -843,7 +854,7 @@ def context_attention_fwd(q,
max_seq_len = 0 if max_seq_len is None else max_seq_len
extra_kargs = {}
if current_platform.is_rocm():
extra_kargs = {"kpack": 2, "waves_per_eu": 2}
extra_kargs = {"kpack": 1, "waves_per_eu": 2}
grid = lambda META: (batch, head,
triton.cdiv(max_input_len, META["BLOCK_M"]))
@@ -853,6 +864,7 @@ def context_attention_fwd(q,
v,
k_cache,
v_cache,
sinks,
b_loc,
sm_scale,
k_scale,


@@ -52,6 +52,7 @@ def kernel_unified_attention_2d(
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
@@ -131,7 +132,15 @@ def kernel_unified_attention_2d(
block_table_offset = seq_idx * block_table_stride
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
if sink_ptr is None:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
M = tl.load(
sink_ptr + query_offset_1,
mask=query_mask_1,
other=float("-inf"),
).to(dtype=tl.float32)
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
@@ -292,6 +301,7 @@ def kernel_unified_attention_3d(
query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, num_kv_heads, head_size // x, blk_size, x]
value_cache_ptr, # [num_blks, num_kv_heads, head_size, blk_size]
sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads]
@@ -383,7 +393,15 @@ def kernel_unified_attention_3d(
block_table_offset = seq_idx * block_table_stride
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
if sink_ptr is None or segm_idx != 0:
M = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
else:
M = tl.load(
sink_ptr + query_offset_1,
mask=query_mask_1,
other=float("-inf"),
).to(dtype=tl.float32)
L = tl.full([BLOCK_M], 1.0, dtype=tl.float32)
acc = tl.zeros([BLOCK_M, HEAD_SIZE_PADDED], dtype=tl.float32)
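
The extra `segm_idx != 0` guard is the only difference from the 2D kernel: in the split-K (3D) variant every segment carries its own partial (M, L, acc) state, and seeding all of them would count the sink once per segment after the cross-segment reduction, so only segment 0 gets it. A small PyTorch sketch of the per-segment state and the merge (not part of the diff; single head, two hypothetical segments):

import torch

def segment(logits, v, sink=None):
    # One segment's online softmax. L starts at 1.0 as in the kernel; without a
    # sink, M = -inf makes the first rescale factor exp(-inf) = 0 and wipes the 1.0 out.
    m = sink.clone() if sink is not None else torch.tensor(float("-inf"))
    l, acc = torch.tensor(1.0), torch.zeros(v.shape[1])
    for logit, val in zip(logits, v):
        m_new = torch.maximum(m, logit)
        alpha, p = torch.exp(m - m_new), torch.exp(logit - m_new)
        l, acc, m = l * alpha + p, acc * alpha + p * val, m_new
    return m, l, acc

def merge(seg_a, seg_b):
    # Cross-segment reduction: rescale both partial states to a common max.
    (ma, la, acca), (mb, lb, accb) = seg_a, seg_b
    m = torch.maximum(ma, mb)
    l = la * torch.exp(ma - m) + lb * torch.exp(mb - m)
    acc = acca * torch.exp(ma - m) + accb * torch.exp(mb - m)
    return acc / l

logits, v, sink = torch.randn(16), torch.randn(16, 4), torch.tensor(0.3)
m, l, acc = segment(logits, v, sink)                 # single-segment result
split = merge(segment(logits[:8], v[:8], sink),      # sink folded into segment 0 only
              segment(logits[8:], v[8:]))
assert torch.allclose(acc / l, split, atol=1e-6)
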
@@ -627,6 +645,8 @@ def unified_attention(
v_descale,
alibi_slopes=None,
qq_bias=None,
# Optional tensor for sinks
sinks=None,
):
assert causal, "Only causal attention is supported"
assert q_descale is None, "Q scales not supported"
@@ -635,6 +655,10 @@ def unified_attention(
assert q.element_size() >= 2 or block_size >= 32, \
"Block size must be at least 32 for fp8"
if sinks is not None:
assert sinks.shape[0] == q.shape[1], \
"Sinks must be num_query_heads size"
use_alibi_slopes = alibi_slopes is not None
use_qq_bias = qq_bias is not None
@@ -669,6 +693,7 @@ def unified_attention(
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,
@@ -741,6 +766,7 @@ def unified_attention(
query_ptr=q,
key_cache_ptr=k,
value_cache_ptr=v,
sink_ptr=sinks,
block_tables_ptr=block_table,
seq_lens_ptr=seqused_k,
alibi_slopes_ptr=alibi_slopes,


@@ -17,6 +17,7 @@ if TYPE_CHECKING:
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
VLLM_USE_AITER_UNIFIED_ATTENTION: bool = False
VLLM_FLASH_ATTN_VERSION: Optional[int] = None
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -151,6 +152,8 @@ if TYPE_CHECKING:
VLLM_LOOPBACK_IP: str = ""
VLLM_ALLOW_CHUNKED_LOCAL_ATTN_WITH_HYBRID_KV_CACHE: bool = False
VLLM_ENABLE_RESPONSES_API_STORE: bool = False
VLLM_USE_TRTLLM_CONTEXT_ATTENTION: bool = False
VLLM_USE_TRTLLM_DECODE_ATTENTION: bool = False
def get_default_cache_root():
@@ -326,6 +329,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
(os.getenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "False").lower() in
("true", "1")),
# Use AITER triton unified attention for V1 attention
"VLLM_USE_AITER_UNIFIED_ATTENTION":
lambda:
(os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower() in
("true", "1")),
# Force vllm to use a specific flash-attention version (2 or 3), only valid
# when using the flash-attention backend.
"VLLM_FLASH_ATTN_VERSION":
@@ -1022,9 +1031,13 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_CUDNN_PREFILL":
lambda: bool(int(os.getenv("VLLM_USE_CUDNN_PREFILL", "0"))),
# If set to 1, use the TRTLLM Attention backend in flashinfer.
"VLLM_USE_TRTLLM_ATTENTION":
lambda: os.getenv("VLLM_USE_TRTLLM_ATTENTION", None),
# If set to 1, use the TRTLLM Context Attention backend in flashinfer.
"VLLM_USE_TRTLLM_CONTEXT_ATTENTION":
lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0"))),
# If set to 1, use the TRTLLM Decode Attention backend in flashinfer.
"VLLM_USE_TRTLLM_DECODE_ATTENTION":
lambda: bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0"))),
# Controls garbage collection during CUDA graph capture.
# If set to 0 (default), enables GC freezing to speed up capture time.
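
For reference, the three new flags follow the two parsing conventions already present in this file: VLLM_USE_AITER_UNIFIED_ATTENTION accepts "true"/"1" case-insensitively, while the TRTLLM flags go through bool(int(...)) and therefore only accept integer strings. A standalone sketch of the same parsing (not part of the diff):

import os

# Accepts "true" or "1" in any case; anything else (including unset) is False.
use_aiter_unified = (os.getenv("VLLM_USE_AITER_UNIFIED_ATTENTION", "False").lower()
                     in ("true", "1"))
# Only integer strings parse here; "0"/"1" are expected, "true" raises ValueError.
use_trtllm_context = bool(int(os.getenv("VLLM_USE_TRTLLM_CONTEXT_ATTENTION", "0")))
use_trtllm_decode = bool(int(os.getenv("VLLM_USE_TRTLLM_DECODE_ATTENTION", "0")))
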


@@ -373,6 +373,7 @@ class FlashAttentionImpl(AttentionImpl):
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
sinks: Optional[torch.Tensor] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -410,6 +411,14 @@ class FlashAttentionImpl(AttentionImpl):
raise NotImplementedError(
"FlashAttention does not support fp8 kv-cache on this device.")
self.sinks = sinks
if self.sinks is not None:
assert self.vllm_flash_attn_version == 3, (
"Sinks are only supported in FlashAttention 3")
assert self.sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
"heads in the layer")
def forward(
self,
layer: torch.nn.Module,
@@ -534,6 +543,7 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
)
return output


@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with PagedAttention and Triton prefix prefill."""
from dataclasses import dataclass
from functools import cache
from typing import ClassVar, Optional
import torch
@@ -13,7 +14,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
from vllm.attention.ops.chunked_prefill_paged_decode import (
chunked_prefill_paged_decode)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_unified_attention import unified_attention
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -193,6 +193,15 @@ class TritonAttentionBackend(AttentionBackend):
return TritonAttentionMetadataBuilder
@cache
def use_aiter_unified_attention() -> bool:
"""Check if aiter unified attention should be used."""
# VLLM_ROCM_USE_AITER_MHA must also be set to 0, since it
# defaults to 1
return envs.VLLM_ROCM_USE_AITER \
and envs.VLLM_USE_AITER_UNIFIED_ATTENTION
class TritonAttentionImpl(AttentionImpl):
def __init__(
@@ -207,6 +216,7 @@ class TritonAttentionImpl(AttentionImpl):
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -240,6 +250,29 @@ class TritonAttentionImpl(AttentionImpl):
self.force_prefill_decode_attn = \
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
if not self.force_prefill_decode_attn:
# If not using prefill decode attention, we use the Triton
# unified attention implementation.
if use_aiter_unified_attention():
logger.info_once(
"Using aiter unified attention for TritonAttentionImpl")
from aiter.ops.triton.unified_attention import (
unified_attention)
self.unified_attention = unified_attention
else:
logger.info_once(
"Using vllm unified attention for TritonAttentionImpl")
from vllm.attention.ops.triton_unified_attention import (
unified_attention)
self.unified_attention = unified_attention
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")
def forward(
self,
layer: torch.nn.Module,
@@ -342,28 +375,31 @@ class TritonAttentionImpl(AttentionImpl):
if use_prefill_decode_attn:
# Compute attention and update output up to `num_actual_tokens`.
chunked_prefill_paged_decode(query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale)
chunked_prefill_paged_decode(
query=query[:num_actual_tokens],
key=key[:num_actual_tokens],
value=value[:num_actual_tokens],
output=output[:num_actual_tokens],
kv_cache_dtype=self.kv_cache_dtype,
key_cache=key_cache,
value_cache=value_cache,
block_table=block_table,
query_start_loc=cu_seqlens_q,
seq_lens=seqused_k,
max_seq_len=max_seqlen_k,
max_query_len=max_seqlen_q,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0],
sm_scale=self.scale,
sinks=self.sinks,
)
else:
descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])
unified_attention(
self.unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
v=value_cache,
@@ -381,6 +417,7 @@ class TritonAttentionImpl(AttentionImpl):
q_descale=None, # Not supported
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
)
return output
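
Taken together, the TritonAttentionImpl changes fix the kernel choice at construction time and thread the sinks tensor into whichever path is selected. A condensed sketch of that dispatch (not part of the diff; the helper and its return strings are illustrative only):

def pick_attention_path(force_prefill_decode: bool,
                        rocm_use_aiter: bool,
                        aiter_unified: bool) -> str:
    # VLLM_V1_USE_PREFILL_DECODE_ATTENTION wins: chunked prefill + paged decode,
    # with sinks forwarded via the new `sinks=` kwarg.
    if force_prefill_decode:
        return "chunked_prefill_paged_decode"
    # VLLM_ROCM_USE_AITER and VLLM_USE_AITER_UNIFIED_ATTENTION must both be set.
    if rocm_use_aiter and aiter_unified:
        return "aiter.ops.triton.unified_attention.unified_attention"
    return "vllm.attention.ops.triton_unified_attention.unified_attention"
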


@@ -254,7 +254,11 @@ def get_kv_cache_layout():
# Override with format specified by the user.
cache_layout = envs.VLLM_KV_CACHE_LAYOUT
if cache_layout is None:
cache_layout = get_kv_connector_cache_layout()
if (envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
or envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
cache_layout = "HND"
else:
cache_layout = get_kv_connector_cache_layout()
else:
logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
"detected. Setting KV cache layout to %s.", cache_layout)
@@ -272,7 +276,9 @@ def set_kv_cache_layout(cache_layout: str):
class PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
the same values for the following hyperparameters.
the same values for the following hyperparameters. Should not be used for the
trtllm-gen backend, since it supports different values for these
hyperparameters across layers.
"""
window_left: int
@@ -310,7 +316,8 @@ def get_per_layer_parameters(
def infer_global_hyperparameters(
per_layer_params: dict[str, PerLayerParameters]) -> PerLayerParameters:
"""
Currently, FlashInfer backend only support models in which all layers share
Currently, FlashInfer backends other than trtllm-gen
only support models in which all layers share
the same values for the following hyperparameters:
- `window_left`
- `logits_soft_cap`
@@ -324,15 +331,20 @@ def infer_global_hyperparameters(
param_sets = list(per_layer_params.values())
global_params = param_sets[0]
for params in param_sets:
if params.window_left != global_params.window_left:
raise ValueError(
"Window left is not the same for all layers. One potential fix "
"is to set disable_sliding_window=True")
assert params == global_params, (
"FlashInfer backend currently only supports models in which all "
"layers share the same values for the following hyperparameters: "
"`window_left`, `logits_soft_cap`, `sm_scale`.")
# trtllm attention doesn't need global hyperparameters, so disable the check
if (not envs.VLLM_USE_TRTLLM_CONTEXT_ATTENTION
and not envs.VLLM_USE_TRTLLM_DECODE_ATTENTION):
for params in param_sets:
if params.window_left != global_params.window_left:
raise ValueError(
"Window left is not the same for all layers. " \
"One potential fix is to set disable_sliding_window=True")
assert params == global_params, (
"FlashInfer backend currently only supports models in which all"
"layers share the same values "
"for the following hyperparameters:"
"`window_left`, `logits_soft_cap`, `sm_scale`.")
return global_params
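
The relaxed loop reads as: per-layer hyperparameters must still be identical unless one of the TRTLLM flags is set, in which case per-layer values are tolerated. A minimal sketch of the uniformity check itself (not part of the diff; the layer names and values are made up):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Params:
    window_left: int
    logits_soft_cap: Optional[float]
    sm_scale: float

per_layer = {"layers.0.attn": Params(-1, None, 0.125),
             "layers.1.attn": Params(-1, None, 0.125)}
first = next(iter(per_layer.values()))
# Dataclass equality compares all fields at once, which is what
# `params == global_params` relies on above.
assert all(p == first for p in per_layer.values()), "per-layer hyperparameters differ"
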