# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

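"""FlashAttention-3 based MLA (Multi-head Latent Attention) backend for vLLM v1.

This module plugs FlashAttention's varlen kernel into the common MLA machinery
for the decode path: short prefills are reordered onto the decode pathway,
FA3's ahead-of-time scheduler metadata is reused so decode can run under full
CUDA graphs, and attention is issued over the latent KV cache with a single KV
head (the "absorbed" MLA formulation).
"""
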
from dataclasses import dataclass
from typing import ClassVar

import torch

from vllm import envs
from vllm.attention.backends.abstract import (
    AttentionLayer,
    AttentionType,
    is_quantized_kv_cache,
)
from vllm.attention.utils.fa_utils import (
    flash_attn_supports_mla,
    get_flash_attn_version,
)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
    vllm_is_batch_invariant,
)
from vllm.v1.attention.backends.mla.common import (
    MLACommonBackend,
    MLACommonDecodeMetadata,
    MLACommonImpl,
    MLACommonMetadata,
    MLACommonMetadataBuilder,
    QueryLenSupport,
)
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata

logger = init_logger(__name__)


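# Backend registration class: tells vLLM which metadata, builder, and impl
# classes make up the FLASH_ATTN_MLA backend.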
class FlashAttnMLABackend(MLACommonBackend):
    @staticmethod
    def get_name() -> str:
        return "FLASH_ATTN_MLA"

    @staticmethod
    def get_metadata_cls() -> type["FlashAttnMLAMetadata"]:
        return FlashAttnMLAMetadata

    @staticmethod
    def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]:
        return FlashAttnMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> type["FlashAttnMLAImpl"]:
        return FlashAttnMLAImpl


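# Decode-only metadata consumed by flash_attn_varlen_func below:
# `query_start_loc`/`max_query_len` describe the varlen decode queries,
# `max_seq_len` is the longest KV sequence, `scheduler_metadata` holds the
# optional FA3 ahead-of-time schedule, and `max_num_splits` is the split-KV
# upper bound used with CUDA graphs.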
@dataclass
class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata):
    query_start_loc: torch.Tensor
    max_query_len: int
    max_seq_len: int
    scheduler_metadata: torch.Tensor | None = None
    max_num_splits: int = 0


@dataclass
class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]):
    pass


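# Per-step metadata builder. Supports uniform-batch full CUDA graphs and
# variable-length decode queries; requests with at most
# `reorder_batch_threshold` query tokens are routed through the decode pathway,
# so small prefills reuse the decode kernel.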
class FlashAttnMLAMetadataBuilder(MLACommonMetadataBuilder[FlashAttnMLAMetadata]):
    cudagraph_support: ClassVar[AttentionCGSupport] = AttentionCGSupport.UNIFORM_BATCH
    query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.VARLEN
    reorder_batch_threshold: int = 512  # process small prefills with decode pathway

    def __init__(
        self,
        kv_cache_spec: AttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
    ):
        super().__init__(
            kv_cache_spec, layer_names, vllm_config, device, FlashAttnMLAMetadata
        )
        self.max_num_splits = 0  # No upper bound on the number of splits.
        self.fa_aot_schedule = get_flash_attn_version() == 3

        self.use_full_cuda_graph = (
            self.compilation_config.cudagraph_mode.has_full_cudagraphs()
        )

        if self.use_full_cuda_graph and self.fa_aot_schedule:
            self.max_cudagraph_size = self.compilation_config.max_capture_size

            if self.max_cudagraph_size > 992:
                # This condition derives from FA3's internal heuristic.
                # TODO(woosuk): Support larger cudagraph sizes.
                raise ValueError(
                    "Capture size larger than 992 is not supported for full cuda graph."
                )

            self.scheduler_metadata = torch.zeros(
                vllm_config.scheduler_config.max_num_seqs + 1,
                dtype=torch.int32,
                device=self.device,
            )
            # When using cuda graph, we need to set the upper bound of the
            # number of splits so that large enough intermediate buffers are
            # pre-allocated during capture.
            self.max_num_splits = envs.VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH

        if vllm_is_batch_invariant():
            self.max_num_splits = 1

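    # Pre-compute FA3's ahead-of-time scheduler metadata for the decode call
    # (returns None when AOT scheduling is unavailable). The head dims reflect
    # the absorbed MLA layout: a single KV head whose key dim is the rope part
    # and whose value dim is the latent kv_lora_rank part.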
    def _schedule_decode(
        self, num_reqs, cu_query_lens, max_query_len, seqlens, max_seq_len, causal
    ):
        if self.fa_aot_schedule:
            return get_scheduler_metadata(
                batch_size=num_reqs,
                max_seqlen_q=max_query_len,
                max_seqlen_k=max_seq_len,
                num_heads_q=self.num_heads * self.dcp_world_size,
                num_heads_kv=1,
                headdim=self.mla_dims.qk_rope_head_dim,
                cache_seqlens=seqlens,
                qkv_dtype=self.kv_cache_spec.dtype,
                headdim_v=self.mla_dims.kv_lora_rank,
                page_size=self.page_size,
                cu_seqlens_q=cu_query_lens,
                causal=causal,
                num_splits=self.max_num_splits,
            )
        return None

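    # Assemble the per-step decode metadata. Under full CUDA graphs, the fresh
    # scheduler metadata is copied into the persistent buffer captured by the
    # graph, and a fixed split upper bound is used so intermediate buffers are
    # sized at capture time.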
    def _build_decode(
        self,
        block_table_tensor: torch.Tensor,
        seq_lens_cpu: torch.Tensor,
        seq_lens_device: torch.Tensor,
        query_start_loc_cpu: torch.Tensor,
        query_start_loc_device: torch.Tensor,
        num_decode_tokens: int,
        dcp_tot_seq_lens_device: torch.Tensor | None,
    ) -> FlashAttnMLADecodeMetadata:
        query_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        max_query_len = query_lens_cpu.max().item()
        max_seq_len = seq_lens_device.max().item()

        scheduler_metadata = self._schedule_decode(
            num_reqs=seq_lens_cpu.numel(),
            cu_query_lens=query_start_loc_device,
            max_query_len=max_query_len,
            seqlens=seq_lens_device,
            max_seq_len=max_seq_len,
            causal=True,
        )

        # For FA3 + full cudagraph
        max_num_splits = 0
        if self.use_full_cuda_graph and scheduler_metadata is not None:
            n = scheduler_metadata.shape[0]
            # Ensure the persistent buffer is large enough
            assert n <= self.scheduler_metadata.shape[0], (
                f"Scheduler metadata size {n} exceeds buffer size "
                + f"{self.scheduler_metadata.shape[0]}"
            )
            self.scheduler_metadata[:n] = scheduler_metadata
            # NOTE(woosuk): We should zero out the rest of the scheduler
            # metadata to guarantee correctness. Otherwise, some thread
            # blocks may use the invalid scheduler metadata and overwrite the
            # output buffer.
            self.scheduler_metadata[n:] = 0
            scheduler_metadata = self.scheduler_metadata[:n]

            if num_decode_tokens <= self.max_cudagraph_size:
                # NOTE(woosuk): Setting num_splits > 1 may increase the memory
                # usage, because the intermediate buffers of size [num_splits,
                # num_heads, num_tokens, head_size] are allocated. Therefore,
                # we only set num_splits when using cuda graphs.
                max_num_splits = self.max_num_splits

        if vllm_is_batch_invariant():
            max_num_splits = 1

        metadata = FlashAttnMLADecodeMetadata(
            block_table=block_table_tensor,
            seq_lens=seq_lens_device,
            query_start_loc=query_start_loc_device,
            max_query_len=max_query_len,
            max_seq_len=max_seq_len,
            scheduler_metadata=scheduler_metadata,
            max_num_splits=max_num_splits,
            dcp_tot_seq_lens=dcp_tot_seq_lens_device,
        )
        return metadata


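# Decode-path implementation. Prefill is handled by the shared MLACommonImpl
# machinery; this class overrides `_forward_decode` to call FlashAttention-3's
# varlen kernel and can return the softmax LSE required for decode context
# parallelism (DCP).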
class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]):
    can_return_lse_for_decode: bool = True

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        # MLA Specific Arguments
        **mla_args,
    ) -> None:
        super().__init__(
            num_heads,
            head_size,
            scale,
            num_kv_heads,
            alibi_slopes,
            sliding_window,
            kv_cache_dtype,
            logits_soft_cap,
            attn_type,
            kv_sharing_target_layer_name,
            **mla_args,
        )

        assert flash_attn_supports_mla(), "FlashAttnMLA is not supported on this device"

        unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashAttnMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, logits_soft_cap"
            )

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError(
                "Encoder self-attention and "
                "encoder/decoder cross-attention "
                "are not implemented for "
                "FlashAttnMLAImpl"
            )

        if is_quantized_kv_cache(self.kv_cache_dtype):
            raise NotImplementedError(
                "FlashAttnMLA V1 with FP8 KV cache not yet supported"
            )

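    # Decode attention in the absorbed MLA form. Roughly, per decode token and
    # query head the kernel computes
    #   logits[j] ~ q_pe . k_pe_cache[j] + q_nope . kv_c_cache[j]
    #   out       = sum_j softmax(logits)[j] * kv_c_cache[j]
    # so the latent cache serves as both part of the key and the value, and
    # q_nope is fed to FA3 through the `q_v` argument.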
    def _forward_decode(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashAttnMLAMetadata,
        layer: AttentionLayer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        assert kv_c_and_k_pe_cache.numel() > 0
        assert attn_metadata.decode is not None

        if type(q) is tuple:
            q_nope, q_pe = q
        else:
            q_nope, q_pe = torch.split(
                q, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )

        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 FlashAttention MLA not yet supported")

        kv_c_cache = kv_c_and_k_pe_cache[..., : self.kv_lora_rank]
        k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank :]

        # NOTE(matt): During CUDA graph capture, max_query_len can be 0, but the
        # kernel uses this to calculate grid dimensions. Ensure it's at least 1
        # to prevent invalid grid configuration during graph capture.
        max_seqlen_q = max(attn_metadata.decode.max_query_len, 1)

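        # Single varlen FA3 call over the paged MLA cache: the two cache halves
        # are exposed as one KV head (hence the unsqueeze below), and the
        # scheduler metadata / num_splits prepared by the metadata builder keep
        # the launch configuration stable under full CUDA graphs.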
        attn_out = flash_attn_varlen_func(
            q=q_pe,
            k=k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            v=kv_c_cache.unsqueeze(-2),  # Add head dim of 1
            q_v=q_nope,
            max_seqlen_q=max_seqlen_q,
            cu_seqlens_q=attn_metadata.decode.query_start_loc,
            max_seqlen_k=attn_metadata.decode.max_seq_len,
            seqused_k=attn_metadata.decode.seq_lens,
            block_table=attn_metadata.decode.block_table,
            softmax_scale=self.scale,
            causal=True,
            return_softmax_lse=self.need_to_return_lse_for_decode,
            fa_version=3,  # only version 3 is supported
            scheduler_metadata=attn_metadata.decode.scheduler_metadata,
            num_splits=attn_metadata.decode.max_num_splits,
            cp_world_size=self.dcp_world_size,
            cp_rank=self.dcp_rank,
            cp_tot_seqused_k=attn_metadata.decode.dcp_tot_seq_lens,
        )

        if self.need_to_return_lse_for_decode:
            o, lse = attn_out
            # FA returns LSE in shape [ H, B ] but DCP wants [ B, H ]
            return o, lse.transpose(0, 1)  # [ H, B ] -> [ B, H ]
        else:
            o = attn_out
            return o, None