[Attention] FlashAttention MLA cudagraph support (#23958)

Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2025-09-08 15:05:26 -07:00
committed by GitHub
parent 41183c1fe0
commit 620db1fc58
7 changed files with 118 additions and 29 deletions

View File

@ -61,6 +61,16 @@ backend_configs = {
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# Cutlass MLA on Blackwell
"CutlassMLA":
BackendConfig(
@ -102,7 +112,7 @@ backend_configs = {
test_params_full_cudagraph = []
# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "CutlassMLA"]
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
test_params_full_cudagraph.append(
pytest.param(

View File

@ -73,7 +73,6 @@ def create_and_prepopulate_kv_cache(
kv_c_contexts: list[torch.Tensor],
k_pe_contexts: list[torch.Tensor],
block_size: int,
num_kv_heads: int,
head_size: int,
dtype: torch.dtype,
device: torch.device,
@ -87,7 +86,6 @@ def create_and_prepopulate_kv_cache(
k_pe_contexts: List of key positional embedding context tensors
for each sequence
block_size: Size of each block
num_kv_heads: Number of KV heads (should be 1 for MLA)
head_size: Size of each head (latent dimension)
dtype: Data type for the cache
device: Device to create the cache on
@ -285,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
query_lens = batch_spec.query_lens
num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
num_kv_heads = vllm_config.model_config.get_num_kv_heads(
vllm_config.parallel_config)
head_size = vllm_config.model_config.get_head_size()
dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
block_size = vllm_config.cache_config.block_size
@ -476,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
kv_c_contexts=kv_c_contexts,
k_pe_contexts=k_pe_contexts,
block_size=block_size,
num_kv_heads=num_kv_heads,
head_size=head_size,
dtype=dtype,
device=device,

View File

@ -62,6 +62,16 @@ backend_configs = {
"cudagraph_mode": "FULL_AND_PIECEWISE",
},
specific_gpu_arch=(9, 0)),
# FlashAttention MLA on Hopper
"FlashAttentionMLA":
BackendConfig(name="FlashAttentionMLA",
env_vars={
"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
},
comp_config={
"cudagraph_mode": "FULL_DECODE_ONLY",
},
specific_gpu_arch=(9, 0)),
# FA2
"FA2":
BackendConfig(name="FA2",

View File

@ -443,11 +443,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
self.metadata_cls = metadata_cls \
if metadata_cls is not None else MLACommonMetadata
self.kv_cache_spec = kv_cache_spec
self.device = device
scheduler_config = vllm_config.scheduler_config
self.model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
parallel_config = vllm_config.parallel_config
cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.device = device
self.num_heads = self.model_config.get_num_attention_heads(
parallel_config)
self.mla_dims = get_mla_dims(self.model_config)
@ -608,10 +610,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
prefill.prefill_main = self._fi_prefill_main
prefill.prefill_chunks = self._fi_prefill_chunks
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata:
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> MLACommonDecodeMetadata:
return MLACommonDecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
@ -624,11 +628,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
Currently, only decode is supported for full cudagraphs with MLA.
"""
m = common_attn_metadata
assert m.num_reqs == m.num_actual_tokens, \
assert m.num_reqs <= (m.num_actual_tokens *
self.reorder_batch_threshold), \
"MLA only supports decode-only full CUDAGraph capture. " \
"Make sure all cudagraph capture sizes <= max_num_seq."
assert m.max_query_len == 1 # decode-only
assert m.max_query_len <= self.reorder_batch_threshold # decode only
return self.build(0, m)
@ -819,6 +824,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
seq_lens_device=seq_lens[:num_decodes],
query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
query_start_loc_device=query_start_loc[:num_decodes + 1],
num_decode_tokens=num_decode_tokens,
)
attn_metadata = self.metadata_cls(

View File

@ -17,11 +17,16 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata
logger = init_logger(__name__)
# NOTE(matt): This is an arbitrary number, copied from
# woosuk's implementation in standard FlashAttention backend
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
class FlashAttnMLABackend(MLACommonBackend):
@ -48,6 +53,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
max_query_len: int
max_seq_len: int
scheduler_metadata: Optional[torch.Tensor] = None
max_num_splits: int = 0
@dataclass
@ -57,14 +63,41 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
class FlashAttnMLAMetadataBuilder(
MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
cudagraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.UNIFORM_BATCH
reorder_batch_threshold: ClassVar[int] = 512
def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
vllm_config: VllmConfig, device: torch.device):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
FlashAttnMLAMetadata)
self.max_num_splits = 0 # No upper bound on the number of splits.
self.fa_aot_schedule = (get_flash_attn_version() == 3)
self.use_full_cuda_graph = \
self.compilation_config.cudagraph_mode.has_full_cudagraphs()
if self.use_full_cuda_graph and self.fa_aot_schedule:
self.max_cudagraph_size = self.compilation_config.max_capture_size
if self.max_cudagraph_size > 992:
# This condition derives from FA3's internal heuristic.
# TODO(woosuk): Support larger cudagraph sizes.
raise ValueError(
"Capture size larger than 992 is not supported for "
"full cuda graph.")
self.scheduler_metadata = torch.zeros(
vllm_config.scheduler_config.max_num_seqs + 1,
dtype=torch.int32,
device=self.device,
)
# When using cuda graph, we need to set the upper bound of the
# number of splits so that large enough intermediate buffers are
# pre-allocated during capture.
self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
max_seq_len, causal):
if self.fa_aot_schedule:
@ -81,14 +114,16 @@ class FlashAttnMLAMetadataBuilder(
page_size=self.page_size,
cu_seqlens_q=cu_query_lens,
causal=causal,
num_splits=self.max_num_splits,
)
return None
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor
) -> FlashAttnMLADecodeMetadata:
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> FlashAttnMLADecodeMetadata:
query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
max_query_len = query_lens_cpu.max().item()
max_seq_len = seq_lens_cpu.max().item()
@ -102,6 +137,29 @@ class FlashAttnMLAMetadataBuilder(
causal=True,
)
# For FA3 + full cudagraph
max_num_splits = 0
if self.use_full_cuda_graph and scheduler_metadata is not None:
n = scheduler_metadata.shape[0]
# Ensure the persistent buffer is large enough
assert n <= self.scheduler_metadata.shape[0], \
f"Scheduler metadata size {n} exceeds buffer size " + \
f"{self.scheduler_metadata.shape[0]}"
self.scheduler_metadata[:n] = scheduler_metadata
# NOTE(woosuk): We should zero out the rest of the scheduler
# metadata to guarantee the correctness. Otherwise, some thread
# blocks may use the invalid scheduler metadata and overwrite the
# output buffer.
self.scheduler_metadata[n:] = 0
scheduler_metadata = self.scheduler_metadata[:n]
if num_decode_tokens <= self.max_cudagraph_size:
# NOTE(woosuk): Setting num_splits > 1 may increase the memory
# usage, because the intermediate buffers of size [num_splits,
# num_heads, num_tokens, head_size] are allocated. Therefore,
# we only set num_splits when using cuda graphs.
max_num_splits = self.max_num_splits
return FlashAttnMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
@ -109,6 +167,7 @@ class FlashAttnMLAMetadataBuilder(
max_query_len=max_query_len,
max_seq_len=max_seq_len,
scheduler_metadata=scheduler_metadata,
max_num_splits=max_num_splits,
)
@ -175,12 +234,17 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]
# NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
# kernel uses this to calculate grid dimensions. Ensure it's at least 1
# to prevent invalid grid configuration during graph capture.
max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)
o = flash_attn_varlen_func(
q=q_pe,
k=k_pe_cache.unsqueeze(-2), # Add head dim of 1
v=kv_c_cache.unsqueeze(-2), # Add head dim of 1
q_v=q_nope,
max_seqlen_q=attn_metadata.decode.max_query_len,
max_seqlen_q=max_seqlen_q,
cu_seqlens_q=attn_metadata.decode.query_start_loc,
max_seqlen_k=attn_metadata.decode.max_seq_len,
seqused_k=attn_metadata.decode.seq_lens,
@ -189,6 +253,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
causal=True,
fa_version=3, # only version 3 is supported
scheduler_metadata=attn_metadata.decode.scheduler_metadata,
num_splits=attn_metadata.decode.max_num_splits,
)
return self._v_up_proj(o)

View File

@ -62,7 +62,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
super().__init__(kv_cache_spec, layer_names, vllm_config, device,
FlashMLAMetadata)
self.compilation_config = vllm_config.compilation_config
self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config)
@ -85,10 +84,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
device=self.device,
dtype=torch.int32)
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata:
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> FlashMLADecodeMetadata:
tile_scheduler_metadata, num_splits = \
get_mla_metadata(
seq_lens_device,

View File

@ -104,10 +104,12 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
dtype=torch.int32,
device=device)
def _build_decode(
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata:
def _build_decode(self, block_table_tensor: torch.Tensor,
seq_lens_cpu: torch.Tensor,
seq_lens_device: torch.Tensor,
query_start_loc_cpu: torch.Tensor,
query_start_loc_device: torch.Tensor,
num_decode_tokens: int) -> AiterMLADecodeMetadata:
page_size = self.kv_cache_spec.block_size
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
device = self.device