[Attention] FlashAttention MLA cudagraph support (#23958)
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
@@ -61,6 +61,16 @@ backend_configs = {
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
        specific_gpu_arch=(9, 0)),
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA":
    BackendConfig(name="FlashAttentionMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_DECODE_ONLY",
                  },
                  specific_gpu_arch=(9, 0)),
    # Cutlass MLA on Blackwell
    "CutlassMLA":
    BackendConfig(
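
The new entry pairs the FLASH_ATTN_MLA backend override with a decode-only full-cudagraph compilation mode and restricts it to Hopper (SM 9.0). As a rough illustration of what such an entry carries, here is a hedged, self-contained sketch; the BackendConfig dataclass below is a stand-in for illustration, not the test suite's actual class.

# Hedged sketch: a stand-in BackendConfig showing the shape of these test
# entries (env-var override, compilation config, GPU architecture gate).
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class BackendConfigSketch:
    name: str
    env_vars: dict = field(default_factory=dict)
    comp_config: dict = field(default_factory=dict)
    specific_gpu_arch: Optional[tuple] = None  # e.g. (9, 0) for Hopper

flash_attn_mla = BackendConfigSketch(
    name="FlashAttentionMLA",
    env_vars={"VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA"},
    comp_config={"cudagraph_mode": "FULL_DECODE_ONLY"},
    specific_gpu_arch=(9, 0),  # Hopper only
)
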
@@ -102,7 +112,7 @@ backend_configs = {
test_params_full_cudagraph = []

# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "CutlassMLA"]
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
    test_params_full_cudagraph.append(
        pytest.param(
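
FlashAttentionMLA joins the MLA backends that the DeepSeek-V2-Lite full-cudagraph test is parametrized over (before/after versions of the list appear above). A minimal hedged sketch of the parametrization pattern; the test-id format is an illustrative assumption:

# Hedged sketch: each MLA backend becomes one pytest.param entry.
import pytest

MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]

test_params_full_cudagraph = [
    pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend),
                 id=f"mla-{backend}")
    for backend in MLA_backends
]
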
@@ -73,7 +73,6 @@ def create_and_prepopulate_kv_cache(
        kv_c_contexts: list[torch.Tensor],
        k_pe_contexts: list[torch.Tensor],
        block_size: int,
        num_kv_heads: int,
        head_size: int,
        dtype: torch.dtype,
        device: torch.device,
@@ -87,7 +86,6 @@ def create_and_prepopulate_kv_cache(
        k_pe_contexts: List of key positional embedding context tensors
            for each sequence
        block_size: Size of each block
        num_kv_heads: Number of KV heads (should be 1 for MLA)
        head_size: Size of each head (latent dimension)
        dtype: Data type for the cache
        device: Device to create the cache on
@@ -285,8 +283,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
    query_lens = batch_spec.query_lens
    num_q_heads = vllm_config.model_config.get_num_attention_heads(
        vllm_config.parallel_config)
    num_kv_heads = vllm_config.model_config.get_num_kv_heads(
        vllm_config.parallel_config)
    head_size = vllm_config.model_config.get_head_size()
    dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
    block_size = vllm_config.cache_config.block_size
@@ -476,7 +472,6 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str):
        kv_c_contexts=kv_c_contexts,
        k_pe_contexts=k_pe_contexts,
        block_size=block_size,
        num_kv_heads=num_kv_heads,
        head_size=head_size,
        dtype=dtype,
        device=device,
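
The `num_kv_heads` parameter is dropped from the MLA test helper, its docstring, and its call site, which is consistent with MLA's latent KV cache storing a single combined latent-plus-rope vector per token. A hedged sketch of that shape arithmetic; the dimension values are illustrative (roughly DeepSeek-style):

# Hedged sketch: MLA latent KV cache shape needs no separate num_kv_heads.
import torch

num_blocks, block_size = 128, 16
kv_lora_rank, qk_rope_head_dim = 512, 64
head_size = kv_lora_rank + qk_rope_head_dim   # 576 latent + rope dims per token

kv_cache = torch.zeros(num_blocks, block_size, head_size)
print(kv_cache.shape)                          # torch.Size([128, 16, 576])
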
@@ -62,6 +62,16 @@ backend_configs = {
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
        specific_gpu_arch=(9, 0)),
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA":
    BackendConfig(name="FlashAttentionMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_DECODE_ONLY",
                  },
                  specific_gpu_arch=(9, 0)),
    # FA2
    "FA2":
    BackendConfig(name="FA2",
@@ -443,11 +443,13 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
        self.metadata_cls = metadata_cls \
            if metadata_cls is not None else MLACommonMetadata
        self.kv_cache_spec = kv_cache_spec
        self.device = device
        scheduler_config = vllm_config.scheduler_config
        self.model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        parallel_config = vllm_config.parallel_config
        cache_config = vllm_config.cache_config
        self.compilation_config = vllm_config.compilation_config
        self.device = device

        self.num_heads = self.model_config.get_num_attention_heads(
            parallel_config)
        self.mla_dims = get_mla_dims(self.model_config)
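
With `compilation_config` and `device` kept on the common MLA metadata builder, backend-specific builders (the FlashAttention MLA builder below, and FlashMLA, which drops its own copy later in this diff) can query the cudagraph mode directly. A minimal hedged sketch of that access pattern; the config objects here are mocked stand-ins, only `has_full_cudagraphs()` mirrors the method used later in the diff:

# Hedged sketch: subclasses read attributes the common builder now stores.
from types import SimpleNamespace

class _CudagraphMode:
    def has_full_cudagraphs(self) -> bool:
        return True

compilation_config = SimpleNamespace(cudagraph_mode=_CudagraphMode(),
                                     max_capture_size=512)

class _CommonBuilderSketch:
    def __init__(self, compilation_config, device="cpu"):
        self.compilation_config = compilation_config
        self.device = device

class _BackendBuilderSketch(_CommonBuilderSketch):
    def __init__(self, compilation_config):
        super().__init__(compilation_config)
        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()

print(_BackendBuilderSketch(compilation_config).use_full_cuda_graph)  # True
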
@@ -608,10 +610,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
            prefill.prefill_main = self._fi_prefill_main
            prefill.prefill_chunks = self._fi_prefill_chunks

    def _build_decode(
            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
            query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata:
    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> MLACommonDecodeMetadata:
        return MLACommonDecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
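
The before/after signatures above show `_build_decode` gaining a `num_decode_tokens` parameter; the FlashAttention MLA, FlashMLA, and Aiter MLA overrides later in this diff adopt the same signature, and the caller threads the value through. A hedged, framework-free sketch of the widened hook; only the parameter list mirrors the diff, the return value is a plain dict stand-in for the decode-metadata dataclass:

import torch

def _build_decode(block_table_tensor: torch.Tensor,
                  seq_lens_cpu: torch.Tensor,
                  seq_lens_device: torch.Tensor,
                  query_start_loc_cpu: torch.Tensor,
                  query_start_loc_device: torch.Tensor,
                  num_decode_tokens: int) -> dict:
    # num_decode_tokens lets a backend size cudagraph-related buffers by
    # token count rather than request count.
    return {
        "block_table": block_table_tensor,
        "seq_lens": seq_lens_device,
        "num_decode_tokens": num_decode_tokens,
    }
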
@@ -624,11 +628,12 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
        Currently, only decode is supported for full cudagraphs with MLA.
        """
        m = common_attn_metadata
        assert m.num_reqs == m.num_actual_tokens, \
        assert m.num_reqs <= (m.num_actual_tokens *
                              self.reorder_batch_threshold), \
            "MLA only supports decode-only full CUDAGraph capture. " \
            "Make sure all cudagraph capture sizes <= max_num_seq."

        assert m.max_query_len == 1  # decode-only
        assert m.max_query_len <= self.reorder_batch_threshold  # decode only

        return self.build(0, m)
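
The capture-time check above is relaxed from "exactly one token per request" to "at most `reorder_batch_threshold` tokens per request", so uniform decode batches with more than one query token each can also be captured. A hedged re-statement of the invariant as a standalone check; the concrete numbers are illustrative:

def check_decode_only_capture(num_reqs: int, num_actual_tokens: int,
                              max_query_len: int,
                              reorder_batch_threshold: int) -> None:
    # Each captured request may contribute at most reorder_batch_threshold
    # query tokens (1 for pure decode, more for uniform multi-token decode).
    assert max_query_len <= reorder_batch_threshold
    assert num_actual_tokens <= num_reqs * reorder_batch_threshold

check_decode_only_capture(num_reqs=8, num_actual_tokens=8,
                          max_query_len=1, reorder_batch_threshold=1)
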
@@ -819,6 +824,7 @@ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
                seq_lens_device=seq_lens[:num_decodes],
                query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1],
                query_start_loc_device=query_start_loc[:num_decodes + 1],
                num_decode_tokens=num_decode_tokens,
            )

        attn_metadata = self.metadata_cls(
@@ -17,11 +17,16 @@ from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                    MLACommonImpl,
                                                    MLACommonMetadata,
                                                    MLACommonMetadataBuilder)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

logger = init_logger(__name__)

# NOTE(matt): This is an arbitrary number, copied from
# woosuk's implementation in standard FlashAttention backend
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16


class FlashAttnMLABackend(MLACommonBackend):
@@ -48,6 +53,7 @@ class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
    max_query_len: int
    max_seq_len: int
    scheduler_metadata: Optional[torch.Tensor] = None
    max_num_splits: int = 0


@dataclass
@@ -57,14 +63,41 @@ class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):

class FlashAttnMLAMetadataBuilder(
        MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
    cudagraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.UNIFORM_BATCH

    reorder_batch_threshold: ClassVar[int] = 512

    def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
                 vllm_config: VllmConfig, device: torch.device):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashAttnMLAMetadata)
        self.max_num_splits = 0  # No upper bound on the number of splits.
        self.fa_aot_schedule = (get_flash_attn_version() == 3)

        self.use_full_cuda_graph = \
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()

        if self.use_full_cuda_graph and self.fa_aot_schedule:
            self.max_cudagraph_size = self.compilation_config.max_capture_size

            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for "
                    "full cuda graph.")

            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

    def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens,
                         max_seq_len, causal):
        if self.fa_aot_schedule:
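
The builder now advertises UNIFORM_BATCH cudagraph support and, when full cudagraphs are combined with FA3's ahead-of-time scheduler, pre-allocates a persistent scheduler-metadata buffer and bounds the number of splits up front. A hedged, self-contained sketch of that setup; the helper name, default CPU device, and example sizes are assumptions for illustration, while the constants mirror the diff:

import torch

_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH = 16
MAX_CAPTURE_SIZE_FOR_AOT_SCHEDULE = 992   # FA3 heuristic limit from the diff

def init_cudagraph_state(max_num_seqs: int, max_capture_size: int,
                         device: torch.device = torch.device("cpu")):
    if max_capture_size > MAX_CAPTURE_SIZE_FOR_AOT_SCHEDULE:
        raise ValueError("Capture size larger than 992 is not supported for "
                         "full cuda graph.")
    # Persistent buffer reused across steps so the captured graph always
    # reads from the same storage.
    scheduler_metadata = torch.zeros(max_num_seqs + 1,
                                     dtype=torch.int32,
                                     device=device)
    return scheduler_metadata, _DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

buf, max_splits = init_cudagraph_state(max_num_seqs=256, max_capture_size=512)
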
@@ -81,14 +114,16 @@ class FlashAttnMLAMetadataBuilder(
                page_size=self.page_size,
                cu_seqlens_q=cu_query_lens,
                causal=causal,
                num_splits=self.max_num_splits,
            )
        return None

    def _build_decode(
            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
            query_start_loc_device: torch.Tensor
    ) -> FlashAttnMLADecodeMetadata:
    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> FlashAttnMLADecodeMetadata:
        query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1])
        max_query_len = query_lens_cpu.max().item()
        max_seq_len = seq_lens_cpu.max().item()
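
With `num_decode_tokens` now passed in, the decode path still derives the maxima the AOT scheduler needs from CPU-side tensors, as in the last lines of the hunk. A small hedged example of that arithmetic with made-up lengths:

import torch

query_start_loc_cpu = torch.tensor([0, 1, 2, 3, 4])   # 4 decode requests
seq_lens_cpu = torch.tensor([37, 5, 90, 12])

query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
max_query_len = int(query_lens_cpu.max().item())       # 1 for pure decode
max_seq_len = int(seq_lens_cpu.max().item())            # 90
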
@@ -102,6 +137,29 @@ class FlashAttnMLAMetadataBuilder(
            causal=True,
        )

        # For FA3 + full cudagraph
        max_num_splits = 0
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            # Ensure the persistent buffer is large enough
            assert n <= self.scheduler_metadata.shape[0], \
                f"Scheduler metadata size {n} exceeds buffer size " + \
                f"{self.scheduler_metadata.shape[0]}"
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee the correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

            if num_decode_tokens <= self.max_cudagraph_size:
                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
                # usage, because the intermediate buffers of size [num_splits,
                # num_heads, num_tokens, head_size] are allocated. Therefore,
                # we only set num_splits when using cuda graphs.
                max_num_splits = self.max_num_splits

        return FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
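
Two details in this block are worth spelling out: stale rows of the persistent scheduler-metadata buffer are zeroed so no thread block consumes leftover values, and `num_splits` is only raised above zero under cudagraphs because the split intermediates of shape [num_splits, num_heads, num_tokens, head_size] cost real memory. A hedged sketch with illustrative sizes; the fp32 accumulator width is an assumption:

import torch

persistent = torch.zeros(257, dtype=torch.int32)    # max_num_seqs + 1
fresh = torch.arange(9, dtype=torch.int32)          # metadata for this step

n = fresh.shape[0]
assert n <= persistent.shape[0]
persistent[:n] = fresh
persistent[n:] = 0       # stale entries must not steer other thread blocks
scheduler_metadata = persistent[:n]

# Rough cost of the split-K intermediates when num_splits is raised:
num_splits, num_heads, num_tokens, head_size = 16, 16, 512, 576
bytes_fp32 = 4           # assumed accumulator width
intermediate_bytes = num_splits * num_heads * num_tokens * head_size * bytes_fp32
print(f"{intermediate_bytes / 2**20:.1f} MiB")       # 288.0 MiB
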
@@ -109,6 +167,7 @@ class FlashAttnMLAMetadataBuilder(
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
            max_num_splits=max_num_splits,
        )

@@ -175,12 +234,17 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:]

        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
        # kernel uses this to calculate grid dimensions. Ensure it's at least 1
        # to prevent invalid grid configuration during graph capture.
        max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)

        o = flash_attn_varlen_func(
            q=q_pe,
            k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
            q_v=q_nope,
            max_seqlen_q=attn_metadata.decode.max_query_len,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_q=attn_metadata.decode.query_start_loc,
            max_seqlen_k=attn_metadata.decode.max_seq_len,
            seqused_k=attn_metadata.decode.seq_lens,
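
The before/after pair shows `max_seqlen_q` now being clamped: during graph capture `max_query_len` can be 0, and the kernel derives its launch grid from it. A one-line hedged illustration of the same guard:

max_query_len = 0                       # can happen during CUDA graph capture
max_seqlen_q = max(max_query_len, 1)    # keep the kernel's grid dimensions valid
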
@@ -189,6 +253,7 @@ class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
            causal=True,
            fa_version=3,  # only version 3 is supported
            scheduler_metadata=attn_metadata.decode.scheduler_metadata,
            num_splits=attn_metadata.decode.max_num_splits,
        )

        return self._v_up_proj(o)
@@ -62,7 +62,6 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
        super().__init__(kv_cache_spec, layer_names, vllm_config, device,
                         FlashMLAMetadata)

        self.compilation_config = vllm_config.compilation_config
        self.num_q_heads = vllm_config.model_config.get_num_attention_heads(
            vllm_config.parallel_config)

@@ -85,10 +84,12 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
                                              device=self.device,
                                              dtype=torch.int32)

    def _build_decode(
            self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
            seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
            query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata:
    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens_cpu: torch.Tensor,
                      seq_lens_device: torch.Tensor,
                      query_start_loc_cpu: torch.Tensor,
                      query_start_loc_device: torch.Tensor,
                      num_decode_tokens: int) -> FlashMLADecodeMetadata:
        tile_scheduler_metadata, num_splits = \
            get_mla_metadata(
                seq_lens_device,
|
@ -104,10 +104,12 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
|
||||
def _build_decode(
|
||||
self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata:
|
||||
def _build_decode(self, block_table_tensor: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
seq_lens_device: torch.Tensor,
|
||||
query_start_loc_cpu: torch.Tensor,
|
||||
query_start_loc_device: torch.Tensor,
|
||||
num_decode_tokens: int) -> AiterMLADecodeMetadata:
|
||||
page_size = self.kv_cache_spec.block_size
|
||||
block_table_bounds = (seq_lens_device + page_size - 1) // page_size
|
||||
device = self.device
|
||||
|