Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Attention] Flash MLA for V1 (#13867)
Signed-off-by: Yang Chen <yangche@fb.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Yang Chen <yangche@fb.com>

@@ -161,15 +161,9 @@ class CudaPlatformBase(Platform):
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
                              use_mla) -> str:
-        if use_v1:
-            if use_mla:
-                logger.info("Using Triton MLA backend on V1 engine.")
-                return "vllm.v1.attention.backends.triton_mla.TritonMLABackend"
-            else:
-                logger.info("Using Flash Attention backend on V1 engine.")
-                return ("vllm.v1.attention.backends.flash_attn."
-                        "FlashAttentionBackend")
         if use_mla:
+            # TODO(lucas): refactor to be more concise
+            # we should probably consider factoring out V1 here
             if selected_backend == _Backend.FLASHMLA:
                 from vllm.attention.backends.flashmla import (
                     is_flashmla_supported)

@@ -183,11 +177,26 @@ class CudaPlatformBase(Platform):
                         " (currently only supports block size 64).",
                         block_size)
                 else:
-                    logger.info("Using FlashMLA backend.")
-                    return "vllm.attention.backends.flashmla.FlashMLABackend"
+                    if use_v1:
+                        logger.info("Using FlashMLA backend on V1 engine.")
+                        return ("vllm.v1.attention.backends.mla."
+                                "flashmla.FlashMLABackend")
+                    else:
+                        logger.info("Using FlashMLA backend.")
+                        return ("vllm.attention.backends."
+                                "flashmla.FlashMLABackend")
 
-            logger.info("Using Triton MLA backend.")
-            return "vllm.attention.backends.triton_mla.TritonMLABackend"
+            if use_v1:
+                logger.info("Using Triton MLA backend on V1 engine.")
+                return ("vllm.v1.attention.backends.mla."
+                        "triton_mla.TritonMLABackend")
+            else:
+                logger.info("Using Triton MLA backend.")
+                return "vllm.attention.backends.triton_mla.TritonMLABackend"
+        if use_v1:
+            logger.info("Using Flash Attention backend on V1 engine.")
+            return ("vllm.v1.attention.backends.flash_attn."
+                    "FlashAttentionBackend")
         if selected_backend == _Backend.FLASHINFER:
             logger.info("Using FlashInfer backend.")
             return "vllm.attention.backends.flashinfer.FlashInferBackend"
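
Note: the backend selections above are returned as dotted class paths, not class objects. As a rough standalone illustration (not vLLM's own selector code), such a string can be resolved with a standard-library dynamic import; the qualified name in the commented call is one of the strings returned above.

import importlib

def load_backend(qualname: str):
    """Resolve 'pkg.module.ClassName' to the class it names."""
    module_name, _, class_name = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# Example (requires vLLM with FlashMLA available):
# backend_cls = load_backend(
#     "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend")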

@@ -34,9 +34,8 @@ class _Backend(enum.Enum):
     TORCH_SDPA = enum.auto()
     OPENVINO = enum.auto()
     FLASHINFER = enum.auto()
-    TRITON_MLA = enum.auto()
-    TRITON_MLA_VLLM_V1 = enum.auto()
-    FLASHMLA = enum.auto()
+    TRITON_MLA = enum.auto()  # Supported by V1
+    FLASHMLA = enum.auto()  # Supported by V1
     HPU_ATTN = enum.auto()
     PALLAS = enum.auto()
     PALLAS_VLLM_V1 = enum.auto()

@@ -333,13 +333,16 @@ class MLACommonMetadata:
 T = TypeVar("T", bound=MLACommonMetadata)
 
 
-class MLACommonMetadataBuilder:
+class MLACommonMetadataBuilder(Generic[T]):
     """
     NOTE: Please read the comment at the top of the file before trying to
     understand this class
     """
 
-    def __init__(self, runner: "GPUModelRunner"):
+    def __init__(self,
+                 runner: "GPUModelRunner",
+                 cls: Optional[type[T]] = None):
+        self.cls = cls if cls is not None else MLACommonMetadata
         self.runner = runner
         scheduler_config = runner.scheduler_config
         model_config = runner.model_config

@@ -431,7 +434,7 @@ class MLACommonMetadataBuilder:
         self._num_prefill_tokens = num_prefill_tokens
 
     def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
-              common_prefix_len: int):
+              common_prefix_len: int) -> T:
         device = self.runner.device
         max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
         query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(

@@ -502,7 +505,7 @@ class MLACommonMetadataBuilder:
             assert max(context_chunk_seq_tot) <= \
                 self.chunked_prefill_workspace_size
 
-        return MLACommonMetadata(
+        return self.cls(
             input_positions=input_positions,
             num_actual_tokens=num_actual_tokens,
             max_query_len=max_query_len,
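
Note: the three hunks above make MLACommonMetadataBuilder generic over its metadata type: a subclass passes its own dataclass through the new cls argument and build() then returns that type. The FlashMLAMetadataBuilder in the new file below uses exactly this hook. A minimal standalone sketch of the pattern (the names here are illustrative, not from vLLM):

from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

@dataclass
class BaseMetadata:
    num_actual_tokens: int = 0

T = TypeVar("T", bound=BaseMetadata)

class BaseBuilder(Generic[T]):

    def __init__(self, cls: Optional[type[T]] = None):
        # Subclasses may inject a richer metadata dataclass.
        self.cls = cls if cls is not None else BaseMetadata

    def build(self, num_actual_tokens: int) -> T:
        # The base class constructs whatever metadata type it was given.
        return self.cls(num_actual_tokens=num_actual_tokens)

@dataclass
class FancyMetadata(BaseMetadata):
    extra_field: Optional[int] = None

class FancyBuilder(BaseBuilder[FancyMetadata]):

    def __init__(self):
        super().__init__(cls=FancyMetadata)

    def build(self, num_actual_tokens: int) -> FancyMetadata:
        m = super().build(num_actual_tokens)
        m.extra_field = 42  # backend-specific post-processing
        return m

assert isinstance(FancyBuilder().build(3), FancyMetadata)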

vllm/v1/attention/backends/mla/flashmla.py (new file, 139 lines)
@@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
                                         get_mla_metadata,
                                         is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
                                                   MLACommonImpl,
                                                   MLACommonMetadata,
                                                   MLACommonMetadataBuilder)

logger = init_logger(__name__)


class FlashMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "FLASHMLA_VLLM_V1"

    @staticmethod
    def get_metadata_cls() -> Type["FlashMLAMetadata"]:
        return FlashMLAMetadata

    @staticmethod
    def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
        return FlashMLAMetadataBuilder

    @staticmethod
    def get_impl_cls() -> Type["FlashMLAImpl"]:
        return FlashMLAImpl


@dataclass
class FlashMLAMetadata(MLACommonMetadata):
    decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
                                                   torch.Tensor]] = None
    decode_num_splits: Optional[torch.Tensor] = None


class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):

    def __init__(self, runner):
        super().__init__(runner, cls=FlashMLAMetadata)

        self.num_q_heads = self.runner.model_config.get_num_attention_heads(
            self.runner.parallel_config)

    def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
              common_prefix_len: int):
        m = super().build(num_reqs, num_actual_tokens, max_query_len,
                          common_prefix_len)

        if m.num_decode_tokens is not None and m.num_decode_tokens > 0:
            m.decode_tile_scheduler_metadata, m.decode_num_splits = \
                get_mla_metadata(
                    m.seq_lens[:m.num_decode_tokens],
                    self.num_q_heads,
                    1,  # MQA for the decode path
                )

        return m


class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        assert is_flashmla_supported(), \
            "FlashMLA is not supported on this device"

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "FlashMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "FlashMLAImpl")

    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: FlashMLAMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 FlashMLA not yet supported")

        q = torch.cat([q_nope, q_pe], dim=-1)\
            .unsqueeze(1)  # Add seqlen dim of 1 (decode)

        o, _ = flash_mla_with_kvcache(
            q=q,
            k_cache=kv_c_and_k_pe_cache.unsqueeze(-2),  # Add head dim of 1
            block_table=attn_metadata.block_table[:attn_metadata.num_decodes,
                                                  ...],
            cache_seqlens=attn_metadata.seq_lens[:attn_metadata.
                                                 num_decode_tokens],
            head_dim_v=self.kv_lora_rank,
            tile_scheduler_metadata=attn_metadata.
            decode_tile_scheduler_metadata,
            num_splits=attn_metadata.decode_num_splits,
            softmax_scale=self.scale,
            causal=True,
        )

        return self._v_up_proj_and_o_proj(o)
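
Note: a hypothetical end-to-end smoke test for this backend, assuming a GPU and build where FlashMLA is supported. VLLM_USE_V1 and VLLM_ATTENTION_BACKEND are existing vLLM environment variables; the model name, block_size=64 (required by the selection logic above), and other flags are illustrative choices, not part of this commit.

import os

# Select the V1 engine and force the FLASHMLA backend before importing vLLM.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHMLA"

from vllm import LLM, SamplingParams

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",  # an MLA model
          trust_remote_code=True,
          block_size=64)  # FlashMLA currently requires block size 64
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)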