vllm-dev/vllm/attention/backends/triton_mla.py
Lucas Wilkinson 288cc6c234 [Attention] MLA with chunked prefill (#12639)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: Patrick Horn <patrick.horn@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
2025-02-21 15:30:12 -08:00


# SPDX-License-Identifier: Apache-2.0

from typing import Any, Dict, List, Optional, Type

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.mla.common import (MLACommonBackend,
                                                MLACommonImpl,
                                                MLACommonMetadata)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd


class TritonMLABackend(MLACommonBackend):

    @staticmethod
    def get_name() -> str:
        return "TRITON_MLA"

    @staticmethod
    def get_impl_cls() -> Type["TritonMLAImpl"]:
        return TritonMLAImpl


class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):

    def __init__(
            self,
            num_heads: int,
            head_size: int,
            scale: float,
            num_kv_heads: int,
            alibi_slopes: Optional[List[float]],
            sliding_window: Optional[int],
            kv_cache_dtype: str,
            blocksparse_params: Optional[Dict[str, Any]],
            logits_soft_cap: Optional[float],
            attn_type: str,
            # MLA Specific Arguments
            **mla_args) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window, kv_cache_dtype,
                         blocksparse_params, logits_soft_cap, attn_type,
                         **mla_args)

        unsupported_features = [
            alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
        ]
        if any(unsupported_features):
            raise NotImplementedError(
                "TritonMLAImpl does not support one of the following: "
                "alibi_slopes, sliding_window, blocksparse_params, "
                "logits_soft_cap")

        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "TritonMLAImpl")
    def _forward_decode(
        self,
        q_nope: torch.Tensor,
        q_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: MLACommonMetadata,
    ) -> torch.Tensor:
        assert kv_c_and_k_pe_cache.numel() > 0
        if self.kv_cache_dtype.startswith("fp8"):
            raise NotImplementedError("FP8 Triton MLA not yet supported")

        decode_meta = attn_metadata.decode_metadata
        assert decode_meta is not None
        B = q_nope.shape[0]

        q = torch.cat([q_nope, q_pe], dim=-1)
        o = torch.zeros(B,
                        self.num_heads,
                        self.kv_lora_rank,
                        dtype=q.dtype,
                        device=q.device)

        num_kv_splits = 4  # TODO: heuristic

        # TODO(lucas) Allocate ahead of time
        attn_logits = torch.empty(
            (
                B,
                self.num_heads,
                num_kv_splits,
                # NOTE(lucas) idk why the +1 is here but sglang has it so we
                # just mirror that
                self.kv_lora_rank + 1,
            ),
            dtype=torch.float32,
            device=q.device,
        )

        # Add a head dim of 1
        kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
        kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
        PAGE_SIZE = kv_c_and_k_pe_cache.size(1)

        # Run MQA
        decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
                             decode_meta.block_tables,
                             decode_meta.seq_lens_tensor, attn_logits,
                             num_kv_splits, self.scale, PAGE_SIZE)

        return self._v_up_proj_and_o_proj(o)
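
For context, a minimal selection sketch follows. It is not part of the file above; it assumes the standard VLLM_ATTENTION_BACKEND environment-variable override (matching the string returned by get_name()) and an MLA-based checkpoint, with the model id purely illustrative.

# Minimal sketch: force the Triton MLA backend by name before engine init.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_MLA"  # value from get_name()

from vllm import LLM

# Illustrative model id; any DeepSeek-style checkpoint that uses MLA applies.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
print(llm.generate("Hello, MLA!")[0].outputs[0].text)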