Compare commits

...

12 Commits

Author SHA1 Message Date
0a02744dc8 fix TP 2025-01-31 01:18:56 +00:00
984ffddda6 add cuda graph support to triton_mla attention 2025-01-30 21:12:00 +00:00
135c404fbb review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-01-30 15:11:58 +00:00
7241acbd64 review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-01-30 14:09:46 +00:00
2b140debbb Update vllm/config.py
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
2025-01-30 08:51:42 -05:00
2326814c11 renaming for consistency
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-01-30 04:00:26 +00:00
534cd0006d review comments
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-01-30 03:52:59 +00:00
aa19f297d2 Update vllm/attention/backends/mla/utils.py
Co-authored-by: Michael Goin <mgoin64@gmail.com>
2025-01-29 22:51:37 -05:00
4880a43d20 Update utils.py
Co-authored-by: Michael Goin <mgoin64@gmail.com>
2025-01-29 22:46:43 -05:00
3895bba85a more cleanups
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-01-30 03:18:46 +00:00
f23d126a07 fix VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-01-30 03:18:46 +00:00
ec8c1cf732 squashed commits
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: simon-mo <simon.mo@hey.com>
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
2025-01-30 03:18:46 +00:00
29 changed files with 2262 additions and 37 deletions

View File

@@ -28,6 +28,11 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& scale);
// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
const double scale, const std::string& kv_cache_dtype);

View File

@@ -245,6 +245,51 @@ __global__ void reshape_and_cache_flash_kernel(
}
}
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void concat_and_cache_mla_kernel(
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
// + pe_dim)]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int block_stride, //
const int kv_c_stride, //
const int k_pe_stride, //
const int kv_lora_rank, //
const int pe_dim, //
const int block_size, //
const float* scale //
) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
// NOTE: slot_idx can be -1 if the token is padded
if (slot_idx < 0) {
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
int src_stride, int dst_stride, int size, int offset) {
for (int i = threadIdx.x; i < size; i += blockDim.x) {
const int64_t src_idx = token_idx * src_stride + i;
const int64_t dst_idx = block_idx * block_stride +
block_offset * (kv_lora_rank + pe_dim) + i +
offset;
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
dst[dst_idx] = src[src_idx];
} else {
dst[dst_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
}
}
};
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
}
} // namespace vllm
// KV_T is the stored data type of kv-cache.
@@ -343,6 +388,56 @@ void reshape_and_cache_flash(
CALL_RESHAPE_AND_CACHE_FLASH);
}
// KV_T is the stored data type of kv-cache.
// CACHE_T is the data type of key and value tensors.
// KV_DTYPE is the real data type of kv-cache.
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride, \
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
reinterpret_cast<const float*>(scale.data_ptr()));
void concat_and_cache_mla(
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
torch::Tensor& k_pe, // [num_tokens, pe_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
// pe_dim)]
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
const std::string& kv_cache_dtype, torch::Tensor& scale) {
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
// slot_mapping.size(0) because of padding for CUDA graphs.
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
// both include padding.
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
// since key includes padding for CUDA graphs, while slot_mapping does not.
// In this case, slot_mapping.size(0) represents the actual number of tokens
// before padding.
// For compatibility with both cases, we use slot_mapping.size(0) as the
// number of tokens.
int num_tokens = slot_mapping.size(0);
int kv_lora_rank = kv_c.size(1);
int pe_dim = k_pe.size(1);
int block_size = kv_cache.size(1);
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
int kv_c_stride = kv_c.stride(0);
int k_pe_stride = k_pe.stride(0);
int block_stride = kv_cache.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(kv_lora_rank, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
CALL_CONCAT_AND_CACHE_MLA);
}
namespace vllm {
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
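For reference, the concat_and_cache_mla kernel above is equivalent to the following minimal pure-PyTorch sketch of the cache layout (assuming the kv_cache_dtype="auto" path, i.e. no fp8 scaling; illustrative only, not part of the diff):

import torch

def concat_and_cache_mla_ref(kv_c: torch.Tensor, k_pe: torch.Tensor,
                             kv_cache: torch.Tensor,
                             slot_mapping: torch.Tensor) -> torch.Tensor:
    # kv_c:     [num_tokens, kv_lora_rank]
    # k_pe:     [num_tokens, pe_dim]
    # kv_cache: [num_blocks, block_size, kv_lora_rank + pe_dim]
    block_size = kv_cache.size(1)
    kv_lora_rank = kv_c.size(1)
    for token_idx, slot_idx in enumerate(slot_mapping.tolist()):
        if slot_idx < 0:  # padded token; matches the kernel's early return
            continue
        block_idx, block_offset = divmod(slot_idx, block_size)
        kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c[token_idx]
        kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe[token_idx]
    return kv_cache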

View File

@@ -463,6 +463,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
&reshape_and_cache_flash);
// Concat kv_c and k_pe and cache them.
cache_ops.def(
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
" Tensor! kv_cache,"
" Tensor slot_mapping,"
" str kv_cache_dtype,"
" Tensor scale) -> ()");
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
// Convert the key and value cache to fp8 data type.
cache_ops.def(
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "

View File

@@ -0,0 +1,89 @@
import pytest
import torch
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
def cdiv(a, b):
return (a + b - 1) // b
@pytest.mark.parametrize("B", [3, 5])
@pytest.mark.parametrize("L", [1027, 1025])
@pytest.mark.parametrize("H_Q", [32])
@pytest.mark.parametrize("H_KV", [32, 8])
@pytest.mark.parametrize("D_QK", [128, 192, 576])
@pytest.mark.parametrize("D_V", [128, 512])
@pytest.mark.parametrize("CACHE_SIZE", [16384])
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
assert CACHE_SIZE % PAGE_SIZE == 0
dtype = torch.bfloat16
seq_len = L # This represents the number of tokens already in the sequence
sm_scale = 1.0 / (D_QK**0.5)
num_kv_splits = 8
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
req_to_page = torch.randint(0,
CACHE_SIZE // PAGE_SIZE,
(B, num_pages_per_batch, 1),
device="cuda")
req_to_token = req_to_page * PAGE_SIZE
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
1, 1, -1)
req_to_token = req_to_token.view(B, -1)
req_to_token = req_to_token[:, :seq_len].contiguous()
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
# k_buffer and v_buffer represent all previous tokens
# Page size is 1.
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
b_seq_len = torch.full((B, ), seq_len, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
device="cuda",
)
# Call the original implementation.
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
)
# Page size can be larger than 1.
k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
o1 = torch.zeros_like(o)
decode_attention_fwd(
q,
k_buffer,
v_buffer,
o1,
req_to_page,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
PAGE_SIZE,
)
assert torch.allclose(o, o1)
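For intuition, the req_to_page to req_to_token expansion in the test above is the usual paged lookup: token t of a request lives at page_id * PAGE_SIZE + t % PAGE_SIZE. A tiny standalone sketch with hypothetical page ids, equivalent to the expand-and-add used in the test:

import torch

PAGE_SIZE = 4
req_to_page = torch.tensor([7, 2]).view(1, 2, 1)  # one request owning pages [7, 2]
req_to_token = (req_to_page * PAGE_SIZE +
                torch.arange(PAGE_SIZE).view(1, 1, -1)).view(1, -1)
print(req_to_token)  # tensor([[28, 29, 30, 31,  8,  9, 10, 11]])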

View File

@@ -980,6 +980,19 @@ def reshape_and_cache_flash(
v_scale)
def concat_and_cache_mla(
kv_c: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache,
slot_mapping, kv_cache_dtype,
scale)
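A hedged usage sketch for this wrapper (shapes follow the CUDA binding above; the sizes, dtype and device here are illustrative, not prescribed by the PR):

import torch
from vllm import _custom_ops as ops

num_tokens, kv_lora_rank, pe_dim = 16, 512, 64
num_blocks, block_size = 8, 16

kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=torch.bfloat16, device="cuda")
k_pe = torch.randn(num_tokens, pe_dim, dtype=torch.bfloat16, device="cuda")
kv_cache = torch.zeros(num_blocks, block_size, kv_lora_rank + pe_dim,
                       dtype=torch.bfloat16, device="cuda")
slot_mapping = torch.arange(num_tokens, dtype=torch.long, device="cuda")
scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping,
                         kv_cache_dtype="auto", scale=scale)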
def copy_blocks(key_caches: List[torch.Tensor],
value_caches: List[torch.Tensor],
block_mapping: torch.Tensor) -> None:

View File

@@ -168,7 +168,8 @@ class AttentionState(ABC, Generic[T]):
@abstractmethod
@contextmanager
def graph_capture(self, max_batch_size: int):
def graph_capture(self, max_batch_size: int,
positions: Optional[torch.Tensor]):
"""Context manager used when capturing CUDA graphs."""
yield
@@ -268,9 +269,25 @@ class AttentionImpl(ABC, Generic[T]):
def forward(
self,
layer: AttentionLayer,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
query: torch.Tensor, # For MLA hidden_states_or_cq
key: torch.Tensor, # For MLA kv_c_normed
value: torch.Tensor, # For MLA k_pe
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
class MLAAttentionImpl(AttentionImpl):
@abstractmethod
def forward(
self,
layer: AttentionLayer,
hidden_states_or_cq: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,

View File

@@ -213,7 +213,10 @@ class FlashInferState(AttentionState):
return self._decode_wrapper
@contextmanager
def graph_capture(self, max_batch_size: int):
def graph_capture(self, max_batch_size: int,
positions: Optional[torch.Tensor]):
assert positions is None
self._is_graph_capturing = True
self._graph_decode_wrapper = None
self._graph_slot_mapping = torch.full((max_batch_size, ),

View File

View File

@@ -0,0 +1,364 @@
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
import torch
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func
@dataclass(kw_only=True)
class MLAMetadataCommon(AttentionMetadata):
# Input positions for rotary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
class MLACommonImpl(MLAAttentionImpl):
"""
Common class for implementing repeated parts
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the entire KV cache.
* The attention "simulates" a multi-head attention, while the compute is
similar to multi-query attention.
* The dataflow is as follows,
* B: batch/sequence length
* H: hidden size
* N: number of attention heads
* Lq: latent dimension for Q
* Lkv: latent dimension for K/V
* P: nope dimension, P+R is the actual head_dim in common attention.
* R: rope dimension, this slice of the head_dim goes through rope.
* V: V head dim.
* kv_c: latent/compressed KV
* q_c: latent/compressed Q
#
# Outside the MLA attention backend
#
1. The hidden states (B, H) are projected down into q_c (B, Lq) and
kv_c_k_pe (B, Lkv+R).
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). q_c
and kv_c are normalized.
#
# Inside the MLA attention backend
#
* if prefill:
3. The q_c is then projected up into the multi-head version.
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
(B, N, P) and q_pe (B, N, R).
4. q_pe, k_pe are then passed through rotary embeddings.
5. kv_c and k_pe are concatenated and inserted into the cache
6. The kv_c is then projected up into the multi-head version.
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
dimensions for K and V, which is split into k_nope (B, N, P)
and v (B, N, V).
7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
q_nope, q_pe, k_nope, k_pe.
8. Attention is computed with q, k, v.
9. The attention computation returns (B, N, V), which is projected back
to (B, H) using out projection.
* if decode:
3. Here's the change: we do not perform the full up-projection for
q_c, and there is no up-projection at all for kv_c. This is
achieved by the technique of "weight absorption". The paper says
"Fortunately, due to the associative law of matrix multiplication,
we can absorb W_UK into W_UQ, and W_UV into W_O".
* The q up projection turns (B, Lq) into (B, N, (P+R)), we split it
into W_UQ (Lq, N, P) and W_QR (Lq, N, R).
* The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split
it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V).
* The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H).
* We can precompute the product of W_UQ and W_UK into
W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in
attention.
* We can precompute the product of W_UV and W_O into
W_UV_O (N, Lkv, H), which is possible due to V@O as the
"epilogue" of attention
4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent.
5. q_pe, k_pe are then passed through rotary embeddings.
6. kv_c and k_pe are concatenated and inserted into the cache
7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape
(B, N, Lkv).
8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe,
kv_c, k_pe. v (B, Lkv) is exactly the same vector as kv_c.
9. The attention is computed with q, k, v. Note that we just performed
a MQA attention with (LKv+R) as our head dim.
10. The KV cache is updated using the new entries k (B, N, (Lkv+R)),
which include the v and rope values.
11. The attention computation returns (B, N, Lkv), which is projected
back to (B, H) using W_UV_O.
From @tsu-bin's calculation, we only want to use the absorption technique
for decode. The prefill algorithm should still use the up-projected MHA
for fewer FLOPs and less memory usage.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
q_lora_rank: Optional[int],
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
# attention backend perspective we rely on the layer to pass in the
# correct matrix
q_proj: Optional[ColumnParallelLinear],
kv_b_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.rotary_emb = rotary_emb
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
return self.o_proj_absorbed(
x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
self.num_heads * self.v_head_dim))[0]
def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
x = torch.matmul(x, self.W_Q)\
.view(-1, self.num_heads, self.qk_nope_head_dim)
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
def process_weights_after_loading(self):
kv_b_proj_weight = self.kv_b_proj.weight.T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj = self.q_proj.weight.T\
.view(-1, self.num_heads, self.qk_head_dim)
# can be W_Q or W_UQ depending on q_lora_rank: the former if
# q_lora_rank is None, the latter otherwise. From the attention backend
# perspective, though, we call both of these W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj[..., :self.qk_nope_head_dim]
self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self.W_UK, self.W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# We absorb `W_UK` into `W_Q`, resulting in either W_Q_UK or W_UQ_UK
# depending on q_lora_rank: the former if q_lora_rank is None, the
# latter otherwise. Basically, if q_lora_rank is None we are absorbing
# into q_proj instead of W_UQ.
self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()
W_O = self.o_proj.weight\
.view(-1, self.num_heads, self.v_head_dim)
self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()
tp_size = get_tensor_model_parallel_world_size()
self.o_proj_absorbed = RowParallelLinear(
self.W_UV_O.shape[0] * tp_size,
self.W_UV_O.shape[1],
bias=False,
# TODO(lucas) figure out how to properly forward quant_method
#quant_config=self.o_proj.quant_method,
)
self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
else:
self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
@abstractmethod
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: MLAMetadataCommon,
) -> torch.Tensor:
raise NotImplementedError
@abstractmethod
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: MLAMetadataCommon,
) -> torch.Tensor:
raise NotImplementedError
def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor, # query in unified attn
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: MLAMetadataCommon,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if output is not None:
raise NotImplementedError(
"output is not yet supported for MLAImplBase")
is_decode = attn_metadata.decode_metadata is not None
is_prefill = attn_metadata.prefill_metadata is not None
if (is_decode and is_prefill):
raise NotImplementedError(
"chunked prefill is not supported for MLAImplBase")
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
if is_decode:
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim)
q_pe, k_pe = \
self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
else:
assert is_prefill
q = self.q_proj(hidden_states_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
# TODO(lucas): there must be a nicer way to write this line
q[..., self.qk_nope_head_dim:], k_pe = \
self.rotary_emb(
attn_metadata.input_positions,
q[..., self.qk_nope_head_dim:], k_pe)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
if attn_metadata.prefill_metadata is not None:
return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata)
if attn_metadata.decode_metadata is not None:
return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)
# Optional common flash-attn based prefill
def _forward_prefill_flash(
self,
q: torch.Tensor,
k_c_normed: torch.Tensor,
k_pe: torch.Tensor,
seq_start_loc: torch.Tensor,
max_prefill_seq_len: int,
) -> torch.Tensor:
kv_nope = self.kv_b_proj(k_c_normed)[0]\
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
attn_output = flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=seq_start_loc,
cu_seqlens_k=seq_start_loc,
max_seqlen_q=max_prefill_seq_len,
max_seqlen_k=max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
attn_output = attn_output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(attn_output)[0]
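The zero-padding of v above is safe because each attention output row is a softmax-weighted sum of rows of v, so the padded columns come out exactly zero and can be sliced away afterwards. A small hedged check with hypothetical shapes:

import torch

T, H, D_qk, D_v = 4, 2, 8, 5
probs = torch.softmax(torch.randn(H, T, T), dim=-1)             # attention weights
v = torch.randn(T, H, D_v)
v_padded = torch.nn.functional.pad(v, [0, D_qk - D_v], value=0)  # [T, H, D_qk]

out_padded = torch.einsum("hqk,khd->qhd", probs, v_padded)
out_ref = torch.einsum("hqk,khd->qhd", probs, v)
assert torch.allclose(out_padded[..., :D_v], out_ref)
assert out_padded[..., D_v:].abs().max() == 0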

View File

@@ -0,0 +1,742 @@
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
from vllm.multimodal import MultiModalPlaceholderMap
try:
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
except ImportError:
BatchDecodeMlaWithPagedKVCacheWrapper = None
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
import torch
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionState, AttentionType)
from vllm.attention.backends.mla.utils import MLACommonImpl, MLAMetadataCommon
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
compute_slot_mapping_start_idx,
is_block_tables_empty)
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
if TYPE_CHECKING:
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)
class TritonMLABackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "TRITON_MLA"
@staticmethod
def get_impl_cls() -> Type["TritonMLAImpl"]:
return TritonMLAImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return TritonMLAMetadata
@staticmethod
def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]:
return TritonMLAMetadataBuilder
@staticmethod
def get_state_cls() -> Type["TritonMLAState"]:
return TritonMLAState
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
kv_lora_rank: int, # passed via head_size
) -> Tuple[int, ...]:
# TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank
k_pe_size = kv_lora_rank // 8
return (num_blocks, block_size, kv_lora_rank + k_pe_size)
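For reference, DeepSeek-V2's configuration uses kv_lora_rank = 512 and qk_rope_head_dim = 64, so kv_lora_rank // 8 = 64 and the last cache dimension is 512 + 64 = 576, which is the Lk == 576 case special-cased in the Triton decode kernel further down.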
@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [512]
class TritonMLAState(AttentionState):
def __init__(self, runner):
self.runner = runner
self._is_graph_capturing = False
@contextmanager
def graph_capture(self, max_batch_size: int,
positions: Optional[torch.Tensor]):
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
device=self.runner.device)
self._graph_seq_lens = torch.ones(max_batch_size,
dtype=torch.int32,
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
assert positions is not None
self._positions = positions
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
return self.__class__(self.runner)
def graph_capture_get_metadata_for_batch(
self, batch_size: int, is_encoder_decoder_model: bool = False):
assert self._is_graph_capturing
attn_metadata = self.runner.attn_backend.make_metadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1,
max_decode_query_len=1,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
block_tables=self._graph_block_tables[:batch_size],
use_cuda_graph=True,
input_positions=self._positions[:batch_size],
head_dim=self.runner.model_config.get_head_size())
if is_encoder_decoder_model:
raise NotImplementedError(
"TritonMLAState does not support encoder/decoder yet")
return attn_metadata
def get_graph_input_buffers(self,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_buffers = {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
}
if is_encoder_decoder_model:
raise NotImplementedError(
"TritonMLAState does not support encoder/decoder yet")
return input_buffers
def prepare_graph_input_buffers(self,
input_buffers,
attn_metadata,
is_encoder_decoder_model: bool = False):
input_buffers["seq_lens_tensor"].copy_(
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
if is_encoder_decoder_model:
raise NotImplementedError(
"TritonMLAState does not support encoder/decoder yet")
def begin_forward(self, model_input):
return
@dataclass(kw_only=True)
class TritonMLAMetadata(MLAMetadataCommon):
"""Metadata for TritonMLAMetadata.
NOTE: Any python object stored here is not updated when it is
cuda-graph replayed. If you have values that need to be changed
dynamically, it should be stored in tensor. The tensor has to be
updated from `CUDAGraphRunner.forward` API.
"""
# (batch_size,). The sequence length per sequence. Sequence length means
# the computed tokens + new tokens. None if it is a decoding.
seq_lens: Optional[List[int]]
# seq_lens stored as a tensor.
seq_lens_tensor: Optional[torch.Tensor]
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
# Maximum sequence length among prefill batch. 0 if there are decoding
# requests only.
max_prefill_seq_len: int
# Maximum sequence length among decode batch. 0 if there are prefill
# requests only.
max_decode_seq_len: int
# (batch_size,) A tensor of context lengths (tokens that are computed
# so far).
context_lens_tensor: Optional[torch.Tensor]
# (batch_size, max_blocks_per_seq).
# Block addresses per sequence. (Seq id -> list of physical block)
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
# in the kv cache. Each block can contain up to block_size tokens.
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
# captured.
block_tables: Optional[torch.Tensor]
# Whether or not cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# Maximum query length in the batch.
max_query_len: Optional[int] = None
# Max number of query tokens among requests in the batch.
max_decode_query_len: Optional[int] = None
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
# the batch, used to index into subquery. E.g., if the subquery length
# is [4, 6], it is [0, 4, 10].
query_start_loc: Optional[torch.Tensor] = None
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
# the batch, used to index into sequence. E.g., if the sequence length is
# [4, 6], it is [0, 4, 10].
seq_start_loc: Optional[torch.Tensor] = None
_cached_prefill_metadata: Optional["TritonMLAMetadata"] = None
_cached_decode_metadata: Optional["TritonMLAMetadata"] = None
num_prefill_tokens: int
num_kv_splits: int = 4 # TODO(lucas) add heuristic
attn_logits: Optional[torch.Tensor] = None
req_idx: Optional[torch.Tensor] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
def __post_init__(self):
supported_head_sizes = TritonMLABackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim, "
f"received {self.head_dim}.")
@property
def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
if self.num_prefills == 0:
return None
if self._cached_prefill_metadata is not None:
return self._cached_prefill_metadata
assert ((self.seq_lens is not None)
or (self.encoder_seq_lens is not None))
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
query_start_loc = (None if self.query_start_loc is None else
self.query_start_loc[:self.num_prefills + 1])
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[:self.num_prefill_tokens])
seq_lens = (None if self.seq_lens is None else
self.seq_lens[:self.num_prefills])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[:self.num_prefills])
seq_start_loc = (None if self.seq_start_loc is None else
self.seq_start_loc[:self.num_prefills + 1])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])
input_positions = (None if self.input_positions is None else
self.input_positions[:self.num_prefill_tokens])
self._cached_prefill_metadata = TritonMLAMetadata(
num_prefills=self.num_prefills,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=0,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_query_len=0,
max_decode_seq_len=0,
query_start_loc=query_start_loc,
seq_start_loc=seq_start_loc,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=False,
input_positions=input_positions,
head_dim=self.head_dim)
return self._cached_prefill_metadata
@property
def decode_metadata(self) -> Optional["TritonMLAMetadata"]:
if self.num_decode_tokens == 0:
return None
if self._cached_decode_metadata is not None:
return self._cached_decode_metadata
assert ((self.seq_lens_tensor is not None)
or (self.encoder_seq_lens_tensor is not None))
# Compute some attn_metadata fields which default to None
slot_mapping = (None if self.slot_mapping is None else
self.slot_mapping[self.num_prefill_tokens:])
seq_lens_tensor = (None if self.seq_lens_tensor is None else
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
input_positions = (None if self.input_positions is None else
self.input_positions[self.num_prefill_tokens:])
self._cached_decode_metadata = TritonMLAMetadata(
num_prefills=0,
num_prefill_tokens=0,
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None,
seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len,
max_query_len=self.max_query_len,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
# Batch may be composed of prefill|decodes, adjust query start
# indices to refer to the start of decodes. E.g.
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
query_start_loc=(self.query_start_loc[self.num_prefills:] -
self.query_start_loc[self.num_prefills])
if self.query_start_loc is not None else None,
seq_start_loc=self.seq_start_loc[self.num_prefills:]
if self.seq_start_loc is not None else None,
context_lens_tensor=None,
block_tables=block_tables,
use_cuda_graph=self.use_cuda_graph,
input_positions=input_positions,
head_dim=self.head_dim)
return self._cached_decode_metadata
def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch size, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph
if turn_prefills_into_decodes:
# When Multi-Step is enabled with Chunked-Prefill, prefills and
# decodes are scheduled together. In the first step, all the
# prefills turn into decodes. This update reflects that
# conversion.
assert self.num_decode_tokens + self.num_prefills == num_seqs
self.num_decode_tokens += self.num_prefills
self.num_prefills = 0
self.num_prefill_tokens = 0
self.max_prefill_seq_len = 0
self.max_query_len = 1
self.slot_mapping = self.slot_mapping[:num_seqs]
else:
assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)
assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs
assert self.slot_mapping.shape == (num_seqs, )
assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0
assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )
assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )
assert self.block_tables is not None
assert self.block_tables.shape[0] == num_seqs
# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)
ops.advance_step_flashattn(num_seqs=num_seqs,
num_queries=num_queries,
block_size=block_size,
input_tokens=model_input.input_tokens,
sampled_token_ids=sampled_token_ids,
input_positions=model_input.input_positions,
seq_lens=self.seq_lens_tensor,
slot_mapping=self.slot_mapping,
block_tables=self.block_tables)
class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
self.input_builder = input_builder
self.runner = input_builder.runner
self.sliding_window = input_builder.sliding_window
self.block_size = input_builder.block_size
def prepare(self):
self.slot_mapping: List[int] = []
self.prefill_seq_lens: List[int] = []
self.context_lens: List[int] = []
self.block_tables: List[List[int]] = []
self.curr_seq_lens: List[int] = []
self.input_positions: List[int] = []
self.multimodal_placeholder_maps: Dict[
str,
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
self.num_prefills = 0
self.num_prefill_tokens = 0
self.num_decode_tokens = 0
self.has_prefix_cache_hit = False
def _add_seq_group(
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
"""Add a sequence group to the metadata. Specifically update/append
1. context length.
2. block table.
3. slot mapping.
"""
is_prompt = inter_data.is_prompt
block_tables = inter_data.block_tables
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
curr_sliding_window_block, input_positions) in zip(
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
inter_data.orig_seq_lens, inter_data.seq_lens,
inter_data.query_lens, inter_data.context_lens,
inter_data.curr_sliding_window_blocks,
inter_data.input_positions):
self.input_positions.extend(input_positions)
self.context_lens.append(context_len)
if is_prompt:
mm_maps = inter_data.multi_modal_placeholder_maps
if mm_maps:
for modality, placeholders in mm_maps.items():
self.multimodal_placeholder_maps[modality].extend(
placeholders)
self.num_prefills += 1
self.num_prefill_tokens += token_len
self.prefill_seq_lens.append(seq_len)
else:
self.num_decode_tokens += query_len
self.curr_seq_lens.append(curr_seq_len)
# Compute block table.
# TODO(sang): Combine chunked prefill and prefix caching by
# only allowing multiple of block_size chunk size.
# NOTE: This only works for oooooooxxx style attention.
block_table = []
if prefix_cache_hit:
# NOTE(woosuk): For flash-attn, the block table should
# include the entries for the incoming prefill tokens.
block_table = block_tables[seq_id]
elif ((chunked_prefill_enabled or not is_prompt)
and block_tables is not None):
if curr_sliding_window_block == 0:
block_table = block_tables[seq_id]
else:
block_table = block_tables[seq_id][
-curr_sliding_window_block:]
self.block_tables.append(block_table)
# Compute slot mapping.
is_profile_run = is_block_tables_empty(block_tables)
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
context_len,
self.sliding_window)
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
seq_len, context_len, start_idx,
self.block_size, inter_data.block_tables)
def _get_graph_runner_block_tables(
self, num_seqs: int,
block_tables: List[List[int]]) -> torch.Tensor:
# The shape of graph_block_tables is
# [max batch size, max context len // block size].
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
assert max_batch_size >= num_seqs
graph_block_tables = self.runner.graph_block_tables[:num_seqs]
for i, block_table in enumerate(block_tables):
if block_table:
num_blocks = len(block_table)
if num_blocks <= max_blocks:
graph_block_tables[i, :num_blocks] = block_table
else:
# It may be possible to have more blocks allocated due
# to lookahead slots of multi-step, however, they are
# not used anyway, so can be safely ignored.
graph_block_tables[
i, :max_blocks] = block_table[:max_blocks]
return torch.from_numpy(graph_block_tables).to(
device=self.runner.device, non_blocking=True)
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
"""Build attention metadata with on-device tensors.
Args:
seq_lens: The maybe padded sequence lengths of the input sequences.
query_lens: The query lengths of the input sequences.
cuda_graph_pad_size: The padding size for cuda graph.
-1 if cuda graph is not used.
batch_size: The maybe padded batch size.
"""
prefix_cache_hit = any([
inter_data.prefix_cache_hit
for inter_data in self.input_builder.inter_data_list
])
for inter_data in self.input_builder.inter_data_list:
self._add_seq_group(inter_data,
self.input_builder.chunked_prefill_enabled,
prefix_cache_hit)
device = self.runner.device
use_captured_graph = cuda_graph_pad_size != -1
max_query_len = max(query_lens)
decode_query_lens = query_lens[self.num_prefills:]
if len(decode_query_lens) > 0:
max_decode_query_len = max(decode_query_lens)
else:
max_decode_query_len = 1
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
max_decode_seq_len = max(self.curr_seq_lens, default=0)
num_decode_tokens = self.num_decode_tokens
query_start_loc = list(accumulate(query_lens, initial=0))
seq_start_loc = list(accumulate(seq_lens, initial=0))
num_seqs = len(seq_lens)
if use_captured_graph:
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
self.block_tables.extend([] * cuda_graph_pad_size)
num_decode_tokens = batch_size - self.num_prefill_tokens
block_tables = self._get_graph_runner_block_tables(
num_seqs, self.block_tables)
else:
block_tables = make_tensor_with_pad(
self.block_tables,
pad=0,
dtype=torch.int,
device=device,
)
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
assert device is not None
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
device, self.runner.pin_memory)
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
self.runner.pin_memory)
input_positions = async_tensor_h2d(self.input_positions, torch.long,
device, self.runner.pin_memory)
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
device, self.runner.pin_memory)
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
device,
self.runner.pin_memory)
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
device, self.runner.pin_memory)
placeholder_index_maps = {
modality: placeholder_map.index_map()
for modality, placeholder_map in
self.multimodal_placeholder_maps.items()
}
num_kv_splits = 8
return TritonMLAMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
query_start_loc=query_start_loc_tensor,
seq_start_loc=seq_start_loc_tensor,
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
input_positions=input_positions,
num_kv_splits=num_kv_splits,
head_dim=self.runner.model_config.get_head_size(),
)
class TritonMLAImpl(MLACommonImpl):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**kwargs) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**kwargs)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
attn_metadata: TritonMLAMetadata,
) -> torch.Tensor:
return self._forward_prefill_flash(q, kv_c_normed, k_pe,
attn_metadata.seq_start_loc,
attn_metadata.max_prefill_seq_len)
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: TritonMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
decode_meta = attn_metadata.decode_metadata
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
(
B,
self.num_heads,
attn_metadata.num_kv_splits,
# NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that
self.kv_lora_rank + 1,
),
dtype=torch.float32,
device=q.device,
)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
decode_meta.block_tables,
decode_meta.seq_lens_tensor, attn_logits,
attn_metadata.num_kv_splits, self.scale,
PAGE_SIZE)
return self._v_up_proj_and_o_proj(o)
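Note that this decode path runs as multi-query attention directly in the latent space: q is [B, num_heads, kv_lora_rank + rope_dim], the single-head paged cache serves as K, and V is simply the first kv_lora_rank columns of the same cache, taken as a view rather than a copy. A minimal illustration of that aliasing (hypothetical sizes):

import torch

cache = torch.zeros(2, 16, 1, 576)     # [num_blocks, page, 1 head, 512 + 64]
kv_c_view = cache[..., :512]           # the "values" seen by the kernel
assert kv_c_view.data_ptr() == cache.data_ptr()  # same storage, no copy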

View File

@@ -2,7 +2,7 @@
from collections import defaultdict
from contextlib import contextmanager
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, Optional)
import numpy as np
import torch
@@ -288,8 +288,11 @@ class CommonAttentionState(AttentionState):
self._is_graph_capturing = False
@contextmanager
def graph_capture(self, max_batch_size: int):
def graph_capture(self, max_batch_size: int, positions: Optional[torch.Tensor]):
assert positions is None
self._is_graph_capturing = True
self._graph_slot_mapping = torch.full((max_batch_size, ),
PAD_SLOT_ID,
dtype=torch.long,
@@ -299,7 +302,9 @@ class CommonAttentionState(AttentionState):
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens

View File

@@ -41,8 +41,10 @@ class Attention(nn.Module):
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
**extra_impl_args,
) -> None:
super().__init__()
if per_layer_sliding_window is not None:
@@ -101,13 +103,18 @@
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
block_size, is_attention_free,
blocksparse_params is not None)
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
blocksparse_params is not None,
use_mla=use_mla)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type)
blocksparse_params, logits_soft_cap, attn_type,
**extra_impl_args)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
@@ -193,6 +200,10 @@
s += f", backend={self.impl.__class__.__name__}"
return s
def process_weights_after_loading(self):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading()
class MultiHeadAttention(nn.Module):
"""Multi-headed attention without any cache, used for ViT."""

View File

@@ -0,0 +1,667 @@
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py
# Changes:
# - Add support for page size >= 1.
# Copyright 2025 vLLM Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for decoding.
It supports page size >= 1.
"""
import logging
import triton
import triton.language as tl
from vllm.platforms import current_platform
is_hip_ = current_platform.is_rocm()
logger = logging.getLogger(__name__)
# TODO: Remove this when triton>=3.2.0. This issue will not affect performance
# and accuracy.
logger.warning(
"The following error message 'operation scheduled before its operands' "
"can be ignored.")
@triton.jit
def tanh(x):
# Tanh is just a scaled sigmoid
return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def _fwd_kernel_stage1(
Q,
K_Buffer,
V_Buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
split_kv_id = tl.program_id(2)
cur_kv_head = cur_head // kv_group_num
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk
mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = cur_batch
off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
q = tl.load(Q + off_q, mask=mask_d, other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
cur_batch_seq_len)
e_max = -float("inf")
e_sum = 0.0
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (kv_loc[:, None] * stride_buf_kbs +
cur_kv_head * stride_buf_kh + offs_d[None, :])
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
other=0.0,
)
qk = tl.sum(q[None, :] * k, 1)
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
cur_kv_head * stride_buf_vh + offs_dv[None, :])
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
n_e_max = tl.maximum(tl.max(qk, 0), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max)
acc *= re_scale
acc += tl.sum(p[:, None] * v, 0)
e_sum = e_sum * re_scale + tl.sum(p, 0)
e_max = n_e_max
offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
split_kv_id * stride_mid_os + offs_dv)
tl.store(
Att_Out + offs_mid_o,
acc / e_sum,
mask=(mask_dv),
)
offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
split_kv_id * stride_mid_os + Lv)
tl.store(
Att_Out + offs_mid_o_1,
e_max + tl.log(e_sum),
)
def _decode_att_m_fwd(
q,
k_buffer,
v_buffer,
att_out,
Req_to_tokens,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
):
BLOCK = 64
NUM_KV_SPLITS = num_kv_splits
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
batch, head_num = q.shape[0], q.shape[1]
grid = (batch, head_num, NUM_KV_SPLITS)
kv_group_num = q.shape[1] // k_buffer.shape[-2]
num_warps = 4 if kv_group_num == 1 else 2
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DV = triton.next_power_of_2(Lv)
_fwd_kernel_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-2),
k_buffer.stride(-1),
v_buffer.stride(-2),
v_buffer.stride(-1),
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
kv_group_num=kv_group_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=num_warps,
num_stages=2,
Lk=Lk,
Lv=Lv,
)
@triton.jit
def _fwd_grouped_kernel_stage1(
Q,
K_Buffer,
V_Buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
Att_Out,
stride_req_to_tokens_b,
stride_qbs,
stride_qh,
stride_buf_kbs,
stride_buf_kh,
stride_buf_vbs,
stride_buf_vh,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
kv_group_num: tl.constexpr,
q_head_num: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_DPE: tl.constexpr,
BLOCK_DV: tl.constexpr,
BLOCK_N: tl.constexpr,
BLOCK_H: tl.constexpr,
NUM_KV_SPLITS: tl.constexpr,
PAGE_SIZE: tl.constexpr,
logit_cap: tl.constexpr,
Lk: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head_id = tl.program_id(1)
cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
split_kv_id = tl.program_id(2)
if kv_group_num > BLOCK_H:
VALID_BLOCK_H: tl.constexpr = BLOCK_H
else:
VALID_BLOCK_H: tl.constexpr = kv_group_num
cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
mask_h = mask_h & (cur_head < q_head_num)
offs_d = tl.arange(0, BLOCK_DMODEL)
offs_dv = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lk
mask_dv = offs_dv < Lv
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = cur_batch
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[
None, :]
q = tl.load(Q + offs_q,
mask=(mask_h[:, None]) & (mask_d[None, :]),
other=0.0)
if BLOCK_DPE > 0:
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
mask_dpe = offs_dpe < Lk
off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
offs_dpe[None, :])
qpe = tl.load(Q + off_qpe,
mask=(mask_h[:, None]) & (mask_dpe[None, :]),
other=0.0)
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
cur_batch_seq_len)
e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
if split_kv_end > split_kv_start:
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
offs_n = start_n + tl.arange(0, BLOCK_N)
kv_page_number = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
offs_n // PAGE_SIZE,
mask=offs_n < split_kv_end,
other=0,
)
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
cur_kv_head * stride_buf_kh + offs_d[:, None])
k = tl.load(
K_Buffer + offs_buf_k,
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
other=0.0,
)
qk = tl.dot(q, k.to(q.dtype))
if BLOCK_DPE > 0:
offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +
cur_kv_head * stride_buf_kh +
offs_dpe[:, None])
kpe = tl.load(
K_Buffer + offs_buf_kpe,
mask=(offs_n[None, :] < split_kv_end) &
(mask_dpe[:, None]),
other=0.0,
)
qk += tl.dot(qpe, kpe.to(qpe.dtype))
qk *= sm_scale
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),
qk, float("-inf"))
offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
cur_kv_head * stride_buf_vh + offs_dv[None, :])
v = tl.load(
V_Buffer + offs_buf_v,
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
other=0.0,
)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
acc *= re_scale[:, None]
acc += tl.dot(p.to(v.dtype), v)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
offs_mid_o = (cur_batch * stride_mid_ob +
cur_head[:, None] * stride_mid_oh +
split_kv_id * stride_mid_os + offs_dv[None, :])
tl.store(
Att_Out + offs_mid_o,
acc / e_sum[:, None],
mask=(mask_h[:, None]) & (mask_dv[None, :]),
)
offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
split_kv_id * stride_mid_os + Lv)
tl.store(
Att_Out + offs_mid_o_1,
e_max + tl.log(e_sum),
mask=mask_h,
)
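
The grouped stage-1 kernel above computes, for each (request, query-head block, KV split), a softmax-weighted partial output together with its log-sum-exp, maintaining the running maximum `e_max`, normalizer `e_sum`, and accumulator `acc` in the usual online-softmax fashion. A minimal NumPy reference for a single head and a single split; the names and shapes here are illustrative and not part of the patch:

```python
import numpy as np

def stage1_partial(q, k, v, sm_scale, logit_cap=0.0):
    """q: [Lk]; k: [split_len, Lk]; v: [split_len, Lv].
    Returns (acc / e_sum, e_max + log(e_sum)), i.e. the values the kernel
    stores at offsets [0, Lv) and Lv of Att_Out for this split."""
    qk = sm_scale * (k @ q)                       # attention logits, [split_len]
    if logit_cap > 0:
        qk = logit_cap * np.tanh(qk / logit_cap)  # soft logit capping
    e_max = qk.max()
    p = np.exp(qk - e_max)                        # unnormalized softmax weights
    e_sum = p.sum()
    acc = p @ v                                   # partial output, [Lv]
    return acc / e_sum, e_max + np.log(e_sum)
```
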
def _decode_grouped_att_m_fwd(
q,
k_buffer,
v_buffer,
att_out,
Req_to_tokens,
B_Seqlen,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
):
BLOCK = 32
Lk = k_buffer.shape[-1]
Lv = v_buffer.shape[-1]
# [TODO] work around the shared-memory limit on AMD MI3xx GPUs
if is_hip_ and Lk >= 576:
BLOCK = 16
if Lk == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
elif Lk == 288:
BLOCK_DMODEL = 256
BLOCK_DPE = 32
else:
BLOCK_DMODEL = triton.next_power_of_2(Lk)
BLOCK_DPE = 0
BLOCK_DV = triton.next_power_of_2(Lv)
batch, head_num = q.shape[0], q.shape[1]
kv_group_num = q.shape[1] // k_buffer.shape[-2]
BLOCK_H = 16
NUM_KV_SPLITS = num_kv_splits
grid = (
batch,
triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
NUM_KV_SPLITS,
)
extra_kargs = {}
if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
_fwd_grouped_kernel_stage1[grid](
q,
k_buffer,
v_buffer,
sm_scale,
Req_to_tokens,
B_Seqlen,
att_out,
Req_to_tokens.stride(0),
q.stride(0),
q.stride(1),
k_buffer.stride(-2),
k_buffer.stride(-1),
v_buffer.stride(-2),
v_buffer.stride(-1),
att_out.stride(0),
att_out.stride(1),
att_out.stride(2),
kv_group_num=kv_group_num,
q_head_num=head_num,
BLOCK_DMODEL=BLOCK_DMODEL,
BLOCK_DPE=BLOCK_DPE,
BLOCK_DV=BLOCK_DV,
BLOCK_N=BLOCK,
BLOCK_H=BLOCK_H,
NUM_KV_SPLITS=NUM_KV_SPLITS,
PAGE_SIZE=page_size,
logit_cap=logit_cap,
num_warps=4,
num_stages=2,
Lk=Lk,
Lv=Lv,
**extra_kargs,
)
@triton.jit
def _fwd_kernel_stage2(
Mid_O,
o,
B_Seqlen,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_obs,
stride_oh,
NUM_KV_SPLITS: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
offs_d = tl.arange(0, BLOCK_DV)
mask_d = offs_d < Lv
e_sum = 0.0
e_max = -float("inf")
acc = tl.zeros([BLOCK_DV], dtype=tl.float32)
offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv
for split_kv_id in range(0, NUM_KV_SPLITS):
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
split_kv_start = kv_len_per_split * split_kv_id
split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
cur_batch_seq_len)
if split_kv_end > split_kv_start:
tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
mask=mask_d,
other=0.0)
tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
n_e_max = tl.maximum(tlogic, e_max)
old_scale = tl.exp(e_max - n_e_max)
acc *= old_scale
exp_logic = tl.exp(tlogic - n_e_max)
acc += exp_logic * tv
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
tl.store(
o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
mask=mask_d,
)
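
Stage 2 then merges the per-split partials using the stored log-sum-exp values, rescaling earlier contributions whenever a later split carries a larger maximum, which reproduces the exact full-sequence softmax. A hedged NumPy reference of the merge, consuming the outputs of the stage-1 sketch above:

```python
import numpy as np

def combine_splits(partials, lses):
    """partials: per-split outputs acc_i / e_sum_i (each of shape [Lv]);
    lses: the matching log-sum-exp values. Returns the exact softmax output."""
    e_max, e_sum = -np.inf, 0.0
    acc = np.zeros_like(partials[0])
    for tv, tlogic in zip(partials, lses):
        n_e_max = max(tlogic, e_max)
        old_scale = np.exp(e_max - n_e_max)   # rescale what has been accumulated
        exp_logic = np.exp(tlogic - n_e_max)  # weight of the incoming split
        acc = acc * old_scale + exp_logic * tv
        e_sum = e_sum * old_scale + exp_logic
        e_max = n_e_max
    return acc / e_sum
```

The split-KV scheme exists to keep the GPU busy at small decode batch sizes; NUM_KV_SPLITS trades a little extra reduction work for more parallelism in stage 1.
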
def _decode_softmax_reducev_fwd(
logits,
q,
o,
v_buffer,
b_seq_len,
num_kv_splits,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)
NUM_KV_SPLITS = num_kv_splits
extra_kargs = {}
if is_hip_:
# https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
# https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
extra_kargs = {
"waves_per_eu": 4,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
grid = (batch, head_num)
_fwd_kernel_stage2[grid](
logits,
o,
b_seq_len,
logits.stride(0),
logits.stride(1),
logits.stride(2),
o.stride(0),
o.stride(1),
NUM_KV_SPLITS=NUM_KV_SPLITS,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
num_warps=4,
num_stages=2,
**extra_kargs,
)
def decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
_decode_att_m_fwd(
q,
k_buffer,
v_buffer,
attn_logits,
req_to_token,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
num_kv_splits)
def decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap=0.0,
):
_decode_grouped_att_m_fwd(
q,
k_buffer,
v_buffer,
attn_logits,
req_to_token,
b_seq_len,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
_decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
num_kv_splits)
def decode_attention_fwd(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size=1,
logit_cap=0.0,
):
assert num_kv_splits == attn_logits.shape[2]
kv_group_num = q.shape[1] // v_buffer.shape[-2]
if kv_group_num == 1:
# MHA
decode_attention_fwd_normal(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
else:
# GQA/MQA/MLA
decode_attention_fwd_grouped(
q,
k_buffer,
v_buffer,
o,
req_to_token,
b_seq_len,
attn_logits,
num_kv_splits,
sm_scale,
page_size,
logit_cap,
)
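
For orientation, a sketch of how the dispatcher above might be invoked for MLA-style decode (a single latent KV head). The shapes are inferred from the kernel indexing, the sizes and `sm_scale` are placeholders, and the call assumes the functions defined above are in scope on a Triton-capable GPU:

```python
import torch

batch, num_q_heads, Lk, Lv = 2, 16, 576, 512      # 576 = 512 latent dims + 64 rope dims
page_size, num_pages, num_kv_splits = 16, 8, 4

q = torch.randn(batch, num_q_heads, Lk, dtype=torch.float16, device="cuda")
k_buffer = torch.randn(num_pages, page_size, 1, Lk, dtype=torch.float16, device="cuda")
v_buffer = torch.randn(num_pages, page_size, 1, Lv, dtype=torch.float16, device="cuda")
req_to_token = torch.arange(num_pages, device="cuda").repeat(batch, 1)   # page table
b_seq_len = torch.tensor([37, 100], dtype=torch.int32, device="cuda")
# Lv partial-output values plus one log-sum-exp slot per (head, split).
attn_logits = torch.empty(batch, num_q_heads, num_kv_splits, Lv + 1,
                          dtype=torch.float32, device="cuda")
o = torch.empty(batch, num_q_heads, Lv, dtype=torch.float16, device="cuda")

decode_attention_fwd(q, k_buffer, v_buffer, o, req_to_token, b_seq_len,
                     attn_logits, num_kv_splits, sm_scale=Lk ** -0.5,
                     page_size=page_size)
```
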

View File

@ -83,6 +83,7 @@ def get_attn_backend(
block_size: int,
is_attention_free: bool,
is_blocksparse: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
@ -97,6 +98,7 @@ def get_attn_backend(
is_attention_free=is_attention_free,
is_blocksparse=is_blocksparse,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
)
@ -109,6 +111,7 @@ def _cached_get_attn_backend(
is_attention_free: bool,
is_blocksparse: bool = False,
use_v1: bool = False,
use_mla: bool = False,
) -> Type[AttentionBackend]:
if is_blocksparse:
logger.info("Using BlocksparseFlashAttention backend.")
@ -141,7 +144,8 @@ def _cached_get_attn_backend(
# get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
use_mla)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}")

View File

@ -736,17 +736,24 @@ class ModelConfig:
def get_hidden_size(self) -> int:
return self.hf_text_config.hidden_size
@property
def is_deepseek_mla(self) -> bool:
return hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
in ('deepseek_v2', 'deepseek_v3'))
def get_head_size(self) -> int:
# TODO: remove hard-coded head sizes
if hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
in ('deepseek_v2', 'deepseek_v3')):
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
0)
qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim",
0)
if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim
if self.is_deepseek_mla:
if self.use_mla:
return self.hf_text_config.kv_lora_rank
else:
qk_rope_head_dim = getattr(self.hf_text_config,
"qk_rope_head_dim", 0)
qk_nope_head_dim = getattr(self.hf_text_config,
"qk_nope_head_dim", 0)
if qk_rope_head_dim and qk_nope_head_dim:
return qk_rope_head_dim + qk_nope_head_dim
if self.is_attention_free:
return 0
@ -805,6 +812,10 @@ class ModelConfig:
def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
"""Returns the number of KV heads per GPU."""
if self.use_mla:
# With MLA, decode attention is effectively MQA over the shared latent cache.
return 1
total_num_kv_heads = self.get_total_num_kv_heads()
# If tensor parallelism is used, we divide the number of KV heads by
# the tensor parallel size. We will replicate the KV heads in the
@ -956,6 +967,10 @@ class ModelConfig:
return ModelRegistry.is_cross_encoder_model(architectures)
@property
def use_mla(self) -> bool:
return self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE
def supported_runner_types(self) -> Set[RunnerType]:
return {_TASK_RUNNER[task] for task in self.supported_tasks}

View File

@ -931,7 +931,6 @@ class EngineArgs:
type=str,
default="auto",
help='The worker class to use for distributed execution.')
parser.add_argument(
"--generation-config",
type=nullable_str,

View File

@ -77,6 +77,8 @@ if TYPE_CHECKING:
V_SCALE_CONSTANT: int = 100
VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False
VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
def get_default_cache_root():
@ -506,6 +508,18 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# TTFT and overall throughput.
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),
# If set, vLLM will disable the MLA attention optimizations.
"VLLM_MLA_DISABLE":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
# Flag that controls whether we perform matrix absorption for MLA decode,
# i.e. fold W_UK into W_Q (forming W_Q_UK) and W_UV into W_O (forming
# W_UV_O). Absorbing the matrices reduces the runtime FLOPs needed to
# compute MLA, but the fused weights take more memory to store.
# Enabled by default.
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
}
# end-env-vars-definition
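
The matrix-absorption flag above is easiest to see with a small example: because the non-rope part of the attention score is q_nope^T (W_UK kv_c), W_UK can be folded into the query side once at weight-processing time, so decode works directly on the cached latent vectors; W_UV can likewise be folded into the output projection. The sketch below is illustrative only, with made-up shapes and names, and is not the code path this PR adds:

```python
import torch

num_heads, d_nope, d_latent = 16, 128, 512
W_UK = torch.randn(num_heads, d_nope, d_latent, dtype=torch.float64)  # latent -> per-head key
q_nope = torch.randn(num_heads, d_nope, dtype=torch.float64)          # one token's queries
kv_c = torch.randn(d_latent, dtype=torch.float64)                     # one cached latent vector

# Unabsorbed: up-project the latent to a full key, then score against it.
k = torch.einsum("hdl,l->hd", W_UK, kv_c)
score_ref = torch.einsum("hd,hd->h", q_nope, k)

# Absorbed: fold W_UK into the query once, then score in latent space directly.
q_latent = torch.einsum("hd,hdl->hl", q_nope, W_UK)
score_abs = torch.einsum("hl,l->h", q_latent, kv_c)

assert torch.allclose(score_ref, score_abs)
```

Setting VLLM_MLA_PERFORM_MATRIX_ABSORPTION=0 keeps the unabsorbed path, which saves the memory for the fused weights at the cost of extra per-token FLOPs.
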

View File

@ -23,6 +23,7 @@ from torch import nn
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.attention import Attention
from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
VllmConfig, set_current_vllm_config)
from vllm.distributed import (get_tensor_model_parallel_rank,
@ -397,6 +398,11 @@ class DefaultModelLoader(BaseModelLoader):
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
elif isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# Some attention modules (currently only MLA) need to post-process
# their weights after loading, e.g. to absorb projection matrices.
module.process_weights_after_loading()
return model.eval()

View File

@ -28,7 +28,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
@ -326,12 +326,156 @@ class DeepseekV2Attention(nn.Module):
return output
class DeepseekV2MLAAttention(nn.Module):
"""
Main references: the DeepSeek-V2 paper (https://arxiv.org/abs/2405.04434)
and the FlashInfer implementation
(https://github.com/flashinfer-ai/flashinfer/pull/551).
For more details see MLACommonImpl in vllm/attention/backends/mla/utils.py.
"""
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj")
else:
self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.mla_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
)
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else:
hidden_states_or_q_c = hidden_states
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
attn_metadata)
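
The forward pass above only produces the latent inputs (hidden_states_or_q_c, kv_c_normed, k_pe); the cache layout and the decode math live in the MLA attention backend (see MLACommonImpl). As a rough, hedged sketch of the decode-time flow implied by the construction above, with head_size=kv_lora_rank and num_kv_heads=1, every query head attends to one shared latent cache:

```python
import torch

def mla_decode_sketch(q_latent: torch.Tensor, q_pe: torch.Tensor,
                      kv_c_cache: torch.Tensor, k_pe_cache: torch.Tensor,
                      scale: float) -> torch.Tensor:
    """Conceptual only. q_latent: [num_heads, kv_lora_rank] (queries already
    mapped into latent space, e.g. via the absorbed W_UK); q_pe: [num_heads,
    qk_rope_head_dim]; caches: [seq_len, kv_lora_rank] and [seq_len, qk_rope_head_dim]."""
    # Single shared latent "KV head": MQA at decode time.
    scores = scale * (q_latent @ kv_c_cache.T + q_pe @ k_pe_cache.T)   # [heads, seq]
    probs = torch.softmax(scores, dim=-1)
    out_latent = probs @ kv_c_cache                                    # [heads, kv_lora_rank]
    return out_latent   # subsequently up-projected by W_UV (or the absorbed o_proj)
```
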
class DeepseekV2DecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
@ -344,7 +488,11 @@ class DeepseekV2DecoderLayer(nn.Module):
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.self_attn = DeepseekV2Attention(
if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
@ -421,6 +569,7 @@ class DeepseekV2Model(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
@ -440,6 +589,7 @@ class DeepseekV2Model(nn.Module):
lambda prefix: DeepseekV2DecoderLayer(
config,
prefix,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
),

View File

@ -31,7 +31,8 @@ class CpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
logger.info("Using Torch SDPA backend.")

View File

@ -157,10 +157,14 @@ class CudaPlatformBase(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1) -> str:
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
if use_v1:
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
if use_mla:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
if selected_backend == _Backend.FLASHINFER:
logger.info("Using FlashInfer backend.")
return "vllm.attention.backends.flashinfer.FlashInferBackend"
@ -171,7 +175,8 @@ class CudaPlatformBase(Platform):
pass
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}")
f"Invalid attention backend for {cls.device_name}, "
f"with use_v1: {use_v1} use_mla: {use_mla}")
target_backend = _Backend.FLASH_ATTN
if not cls.has_device_capability(80):
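
End to end, the selection above means that on CUDA a DeepSeek-V2/V3 model picks the Triton MLA backend automatically (on the V0 engine; V1 still returns FlashAttention first), and VLLM_MLA_DISABLE=1 falls back to the standard attention path. A hedged usage sketch, assuming a vLLM build that includes this change and enough GPU memory for the chosen checkpoint:

```python
import os
os.environ["VLLM_MLA_DISABLE"] = "0"   # default; set to "1" to opt out of MLA

from vllm import LLM

llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite", trust_remote_code=True)
print(llm.generate("Hello")[0].outputs[0].text)
```
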

View File

@ -27,7 +27,8 @@ class HpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
logger.info("Using HPUAttention backend.")
return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"

View File

@ -30,6 +30,7 @@ class _Backend(enum.Enum):
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
TRITON_MLA = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
@ -139,7 +140,8 @@ class Platform:
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
"""Get the attention backend class of a device."""
return ""

View File

@ -30,7 +30,8 @@ class OpenVinoPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
logger.info("Using OpenVINO Attention backend.")

View File

@ -75,7 +75,8 @@ class RocmPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1) -> str:
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:

View File

@ -29,7 +29,8 @@ class TpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
logger.info("Using Pallas backend.")

View File

@ -27,7 +27,8 @@ class XPUPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
logger.info("Using IPEX attention backend.")

View File

@ -56,7 +56,8 @@ class CacheEngine:
model_config.dtype,
cache_config.cache_dtype,
self.block_size,
model_config.is_attention_free)
model_config.is_attention_free,
use_mla=model_config.use_mla)
# Initialize the cache.
self.gpu_cache = self._allocate_kv_cache(

View File

@ -1066,6 +1066,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
) if needs_attn_backend else None
if self.attn_backend:
self.attn_state = self.attn_backend.get_state_cls()(
@ -1467,8 +1468,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
dtype=self.model_config.dtype,
device=self.device)
with self.attn_state.graph_capture(max_batch_size), graph_capture(
self.device) as graph_capture_context:
with self.attn_state.graph_capture(
max_batch_size, input_positions), graph_capture(
self.device) as graph_capture_context:
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
for virtual_engine in range(
@ -1973,7 +1975,8 @@ class CUDAGraphRunner(nn.Module):
# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
self.input_buffers["positions"].copy_(positions, non_blocking=True)
if positions is not None:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_(