Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.10.0rc2...mla_cuda_g (12 commits)

Commit SHA1s: 0a02744dc8, 984ffddda6, 135c404fbb, 7241acbd64, 2b140debbb,
2326814c11, 534cd0006d, aa19f297d2, 4880a43d20, 3895bba85a, f23d126a07,
ec8c1cf732
@@ -28,6 +28,11 @@ void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
                             const std::string& kv_cache_dtype,
                             torch::Tensor& k_scale, torch::Tensor& v_scale);

void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                          torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
                          const std::string& kv_cache_dtype,
                          torch::Tensor& scale);

// Just for unittest
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                 const double scale, const std::string& kv_cache_dtype);
@@ -245,6 +245,51 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void concat_and_cache_mla_kernel(
|
||||
const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank]
|
||||
const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim]
|
||||
cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank
|
||||
// + pe_dim)]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, //
|
||||
const int kv_c_stride, //
|
||||
const int k_pe_stride, //
|
||||
const int kv_lora_rank, //
|
||||
const int pe_dim, //
|
||||
const int block_size, //
|
||||
const float* scale //
|
||||
) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
if (slot_idx < 0) {
|
||||
return;
|
||||
}
|
||||
const int64_t block_idx = slot_idx / block_size;
|
||||
const int64_t block_offset = slot_idx % block_size;
|
||||
|
||||
auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
|
||||
int src_stride, int dst_stride, int size, int offset) {
|
||||
for (int i = threadIdx.x; i < size; i += blockDim.x) {
|
||||
const int64_t src_idx = token_idx * src_stride + i;
|
||||
const int64_t dst_idx = block_idx * block_stride +
|
||||
block_offset * (kv_lora_rank + pe_dim) + i +
|
||||
offset;
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
dst[dst_idx] = src[src_idx];
|
||||
} else {
|
||||
dst[dst_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
|
||||
copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
@@ -343,6 +388,56 @@ void reshape_and_cache_flash(
|
||||
CALL_RESHAPE_AND_CACHE_FLASH);
|
||||
}
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(kv_c.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(k_pe.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, kv_c_stride, \
|
||||
k_pe_stride, kv_lora_rank, pe_dim, block_size, \
|
||||
reinterpret_cast<const float*>(scale.data_ptr()));
|
||||
|
||||
void concat_and_cache_mla(
|
||||
torch::Tensor& kv_c, // [num_tokens, kv_lora_rank]
|
||||
torch::Tensor& k_pe, // [num_tokens, pe_dim]
|
||||
torch::Tensor& kv_cache, // [num_blocks, block_size, (kv_lora_rank +
|
||||
// pe_dim)]
|
||||
torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens]
|
||||
const std::string& kv_cache_dtype, torch::Tensor& scale) {
|
||||
// NOTE(woosuk): In vLLM V1, key.size(0) can be different from
|
||||
// slot_mapping.size(0) because of padding for CUDA graphs.
|
||||
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
|
||||
// both include padding.
|
||||
// In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
|
||||
// since key includes padding for CUDA graphs, while slot_mapping does not.
|
||||
// In this case, slot_mapping.size(0) represents the actual number of tokens
|
||||
// before padding.
|
||||
// For compatibility with both cases, we use slot_mapping.size(0) as the
|
||||
// number of tokens.
|
||||
int num_tokens = slot_mapping.size(0);
|
||||
int kv_lora_rank = kv_c.size(1);
|
||||
int pe_dim = k_pe.size(1);
|
||||
int block_size = kv_cache.size(1);
|
||||
|
||||
TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
|
||||
|
||||
int kv_c_stride = kv_c.stride(0);
|
||||
int k_pe_stride = k_pe.stride(0);
|
||||
int block_stride = kv_cache.stride(0);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(kv_lora_rank, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
|
||||
CALL_CONCAT_AND_CACHE_MLA);
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
|
@@ -463,6 +463,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
cache_ops.impl("reshape_and_cache_flash", torch::kCUDA,
|
||||
&reshape_and_cache_flash);
|
||||
|
||||
// Concat kv_c and k_pe and cache them.
|
||||
cache_ops.def(
|
||||
"concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
|
||||
" Tensor! kv_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" Tensor scale) -> ()");
|
||||
cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
|
||||
|
||||
// Convert the key and value cache to fp8 data type.
|
||||
cache_ops.def(
|
||||
"convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
|
||||
|
tests/kernels/test_triton_decode_attention.py (new file, 89 lines)
@@ -0,0 +1,89 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
|
||||
|
||||
def cdiv(a, b):
|
||||
return (a + b - 1) // b
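# e.g. cdiv(1027, 16) == 65: the final, partially filled page still counts.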
|
||||
|
||||
|
||||
@pytest.mark.parametrize("B", [3, 5])
|
||||
@pytest.mark.parametrize("L", [1027, 1025])
|
||||
@pytest.mark.parametrize("H_Q", [32])
|
||||
@pytest.mark.parametrize("H_KV", [32, 8])
|
||||
@pytest.mark.parametrize("D_QK", [128, 192, 576])
|
||||
@pytest.mark.parametrize("D_V", [128, 512])
|
||||
@pytest.mark.parametrize("CACHE_SIZE", [16384])
|
||||
@pytest.mark.parametrize("PAGE_SIZE", [1, 16])
|
||||
def test_decode_attention(B, L, H_Q, H_KV, D_QK, D_V, CACHE_SIZE, PAGE_SIZE):
|
||||
assert CACHE_SIZE % PAGE_SIZE == 0
|
||||
dtype = torch.bfloat16
|
||||
seq_len = L # This represents the number of tokens already in the sequence
|
||||
sm_scale = 1.0 / (D_QK**0.5)
|
||||
num_kv_splits = 8
|
||||
|
||||
num_pages_per_batch = cdiv(seq_len, PAGE_SIZE)
|
||||
req_to_page = torch.randint(0,
|
||||
CACHE_SIZE // PAGE_SIZE,
|
||||
(B, num_pages_per_batch, 1),
|
||||
device="cuda")
|
||||
req_to_token = req_to_page * PAGE_SIZE
|
||||
req_to_token = req_to_token.expand(B, num_pages_per_batch, PAGE_SIZE)
|
||||
req_to_token = req_to_token + torch.arange(PAGE_SIZE, device="cuda").view(
|
||||
1, 1, -1)
|
||||
req_to_token = req_to_token.view(B, -1)
|
||||
req_to_token = req_to_token[:, :seq_len].contiguous()
|
||||
|
||||
# q represents the new token being generated, one per batch
|
||||
q = torch.randn(B, H_Q, D_QK, dtype=dtype, device="cuda")
|
||||
|
||||
# k_buffer and v_buffer represent all previous tokens
|
||||
# Page size is 1.
|
||||
k_buffer = torch.randn(CACHE_SIZE, H_KV, D_QK, dtype=dtype, device="cuda")
|
||||
v_buffer = torch.randn(CACHE_SIZE, H_KV, D_V, dtype=dtype, device="cuda")
|
||||
|
||||
# o will have the same shape as q
|
||||
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
||||
|
||||
b_seq_len = torch.full((B, ), seq_len, device="cuda")
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D_V + 1),
|
||||
dtype=torch.float32,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
# Call the original implementation.
|
||||
decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
)
|
||||
|
||||
# Page size can be larger than 1.
|
||||
k_buffer = k_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_QK)
|
||||
v_buffer = v_buffer.view(CACHE_SIZE // PAGE_SIZE, PAGE_SIZE, H_KV, D_V)
|
||||
|
||||
o1 = torch.zeros_like(o)
|
||||
|
||||
decode_attention_fwd(
|
||||
q,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o1,
|
||||
req_to_page,
|
||||
b_seq_len,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
PAGE_SIZE,
|
||||
)
|
||||
|
||||
assert torch.allclose(o, o1)
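

# Standalone sketch (not part of the diff): an eager PyTorch reference for the
# decode step exercised above, assuming q [B, H_Q, D_QK],
# k_buffer [CACHE_SIZE, H_KV, D_QK], v_buffer [CACHE_SIZE, H_KV, D_V] and
# req_to_token [B, L] as built in the test. The test itself only checks that
# the PAGE_SIZE == 1 and PAGE_SIZE > 1 paths of decode_attention_fwd agree;
# a reference like this could be used to check absolute correctness as well.
def naive_decode_reference(q, k_buffer, v_buffer, req_to_token, seq_len,
                           sm_scale):
    B, H_Q, _ = q.shape
    H_KV = k_buffer.shape[1]
    group = H_Q // H_KV  # GQA: each KV head serves `group` query heads
    out = torch.empty(B, H_Q, v_buffer.shape[-1], dtype=q.dtype,
                      device=q.device)
    for b in range(B):
        idx = req_to_token[b, :seq_len]                    # cache slots used
        k = k_buffer[idx].repeat_interleave(group, dim=1)  # [L, H_Q, D_QK]
        v = v_buffer[idx].repeat_interleave(group, dim=1)  # [L, H_Q, D_V]
        logits = torch.einsum("hd,lhd->hl", q[b].float(), k.float()) * sm_scale
        out[b] = torch.einsum("hl,lhd->hd", logits.softmax(dim=-1),
                              v.float()).to(q.dtype)
    return out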
|
@@ -980,6 +980,19 @@ def reshape_and_cache_flash(
|
||||
v_scale)
|
||||
|
||||
|
||||
def concat_and_cache_mla(
|
||||
kv_c: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
kv_cache_dtype: str,
|
||||
scale: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops._C_cache_ops.concat_and_cache_mla(kv_c, k_pe, kv_cache,
|
||||
slot_mapping, kv_cache_dtype,
|
||||
scale)
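

# Standalone sketch (not part of the diff): what concat_and_cache_mla computes,
# written in eager PyTorch, with shapes taken from the CUDA kernel comments
# earlier in this diff: kv_c [num_tokens, kv_lora_rank], k_pe
# [num_tokens, pe_dim], kv_cache [num_blocks, block_size, kv_lora_rank + pe_dim]
# and slot_mapping [num_tokens] (entries may be -1 for padded tokens). The fp8
# cache dtypes additionally scale-convert each element, which is omitted here.
def concat_and_cache_mla_reference(kv_c: torch.Tensor, k_pe: torch.Tensor,
                                   kv_cache: torch.Tensor,
                                   slot_mapping: torch.Tensor) -> None:
    block_size = kv_cache.size(1)
    for token_idx, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:  # padded token, nothing to write
            continue
        block_idx, block_offset = divmod(slot, block_size)
        kv_cache[block_idx, block_offset] = torch.cat(
            (kv_c[token_idx], k_pe[token_idx]), dim=-1)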
|
||||
|
||||
|
||||
def copy_blocks(key_caches: List[torch.Tensor],
|
||||
value_caches: List[torch.Tensor],
|
||||
block_mapping: torch.Tensor) -> None:
|
||||
|
@@ -168,7 +168,8 @@ class AttentionState(ABC, Generic[T]):
|
||||
|
||||
@abstractmethod
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
def graph_capture(self, max_batch_size: int,
|
||||
positions: Optional[torch.Tensor]):
|
||||
"""Context manager used when capturing CUDA graphs."""
|
||||
yield
|
||||
|
||||
@@ -268,9 +269,25 @@ class AttentionImpl(ABC, Generic[T]):
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
query: torch.Tensor, # For MLA hidden_states_or_cq
|
||||
key: torch.Tensor, # For MLA kv_c_normed
|
||||
value: torch.Tensor, # For MLA k_pe
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MLAAttentionImpl(AttentionImpl):
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_cq: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: T,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
|
@@ -213,7 +213,10 @@ class FlashInferState(AttentionState):
|
||||
return self._decode_wrapper
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
def graph_capture(self, max_batch_size: int,
|
||||
positions: Optional[torch.Tensor]):
|
||||
assert positions is None
|
||||
|
||||
self._is_graph_capturing = True
|
||||
self._graph_decode_wrapper = None
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
|
vllm/attention/backends/mla/__init__.py (new file, empty)

vllm/attention/backends/mla/utils.py (new file, 364 lines)
@@ -0,0 +1,364 @@
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.abstract import (AttentionLayer,
|
||||
AttentionMetadata,
|
||||
MLAAttentionImpl)
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
from vllm.vllm_flash_attn import flash_attn_varlen_func
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class MLAMetadataCommon(AttentionMetadata):
|
||||
# Input positions for rotary embeddings, since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor
|
||||
|
||||
|
||||
class MLACommonImpl(MLAAttentionImpl):
|
||||
"""
|
||||
Common class for implementing repeated parts
|
||||
|
||||
Main reference: DeepseekV2 paper, and FlashInfer Implementation
|
||||
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
|
||||
|
||||
Deepseek's MLA attention works the following way:
|
||||
* Use a single latent vector to represent the entire KV cache.
|
||||
* The attention "simulates" a multi-head attention, while the compute is
|
||||
similar to multi-query attention.
|
||||
* The dataflow is as follows,
|
||||
|
||||
* B: batch/sequence length
|
||||
* H: hidden size
|
||||
* N: number of attention heads
|
||||
* Lq: latent dimension for Q
|
||||
* Lkv: latent dimension for K/V
|
||||
* P: nope dimension, P+R is the actual head_dim in common attention.
|
||||
* R: rope dimension; this slice of the head_dim goes through rope.
|
||||
* V: V head dim.
|
||||
* kv_c: latent/compressed KV
|
||||
* q_c: latent/compressed Q
|
||||
|
||||
#
|
||||
# Outside the MLA attention backend
|
||||
#
|
||||
|
||||
1. The hidden states (B, H) are projected down into cq (B, Lq) and
|
||||
kv_c_k_pe (B, Lkv+R).
|
||||
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
|
||||
and kv_c are normalized.
|
||||
|
||||
#
|
||||
# Inside the MLA attention backend
|
||||
#
|
||||
|
||||
* if prefill:
|
||||
|
||||
3. The q_c is then projected up into the multi-head version.
|
||||
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
|
||||
(B, N, P) and q_pe (B, N, R).
|
||||
4. q_pe, k_pe are then passed through rotary embeddings.
|
||||
5. kv_c and k_pe are concatenated and inserted into the cache
|
||||
6. The kv_c is then projected up into the multi-head version.
|
||||
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
|
||||
dimensions for K and V, which is split into k_nope (B, N, P)
|
||||
and v (B, N, V).
|
||||
7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
|
||||
q_nope, q_pe, k_nope, k_pe.
|
||||
8. Attention is computed with q, k, v.
|
||||
9. The attention computation returns (B, N, V), which is projected back
|
||||
to (B, H) using out projection.
|
||||
|
||||
* if decode:
|
||||
|
||||
3. Here's the change: we do not perform the full up-projection for
|
||||
q_c, and there is no up-projection at all for kv_c. This is
|
||||
achieved by the technique of "weight absorption". The paper says
|
||||
"Fortunately, due to the associative law of matrix multiplication,
|
||||
we can absorb WUK into WUQ, and WUV into WO"
|
||||
* The q up projection turns (B, Lq) into (B, N, (P+R)), we split it
|
||||
into W_UQ (Lq, N, P) and W_QR (Lq, N, R).
|
||||
* The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split
|
||||
it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V).
|
||||
* The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H).
|
||||
* We can precompute the product of W_UQ and W_UK into
|
||||
W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in
|
||||
attention.
|
||||
* We can precompute the product of W_UV and W_O into
|
||||
W_UV_O (N, Lkv, H), which is possible due to V@O as the
|
||||
"epilogue" of attention
|
||||
4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent.
|
||||
5. q_pe, k_pe are then passed through rotary embeddings.
|
||||
6. kv_c and k_pe are concatenated and inserted into the cache
|
||||
7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape
|
||||
(B, N, Lkv).
|
||||
8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe,
|
||||
kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a.
|
||||
9. The attention is computed with q, k, v. Note that we just performed
|
||||
an MQA attention with (Lkv+R) as our head dim.
|
||||
10. The KV cache is updated using the new entries k (B, N, (Lkv+R)),
|
||||
which include the v and rope values.
|
||||
11. The attention computation returns (B, N, Lkv), which is projected
|
||||
back to (B, H) using W_UV_O.
|
||||
|
||||
From @tsu-bin's calculation, we only want to use the absorption technique
|
||||
for decode. The prefill algorithm should still use the up-projected MHA
|
||||
for lower FLOPs and memory usage.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
q_lora_rank: Optional[int],
|
||||
kv_lora_rank: int,
|
||||
qk_nope_head_dim: int,
|
||||
qk_rope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
rotary_emb: RotaryEmbedding,
|
||||
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
|
||||
# attention backend perspective we rely on the layer to pass in the
|
||||
# correct matrix
|
||||
q_proj: Optional[ColumnParallelLinear],
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
o_proj: RowParallelLinear,
|
||||
) -> None:
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
|
||||
self.rotary_emb = rotary_emb
|
||||
self.q_proj = q_proj
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.o_proj = o_proj
|
||||
|
||||
def _v_up_proj_and_o_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
return self.o_proj_absorbed(
|
||||
x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
|
||||
else:
|
||||
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
|
||||
return self.o_proj(x.reshape(-1,
|
||||
self.num_heads * self.v_head_dim))[0]
|
||||
|
||||
def _q_proj_and_k_up_proj(self, x):
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
return torch.matmul(x, self.W_Q_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
|
||||
else:
|
||||
x = torch.matmul(x, self.W_Q)\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim)
|
||||
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
|
||||
.view(-1, self.num_heads, self.kv_lora_rank)
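

# Standalone sanity-check sketch (not part of the diff): numerically verify
# that the absorbed path in _q_proj_and_k_up_proj (a single matmul with
# W_Q_UK) matches the unabsorbed path (up-project with W_Q, then contract the
# nope dims with W_UK). All shapes below are made-up illustrative values.
def _absorption_equivalence_sketch():
    B, Lq, Lkv, N, P = 2, 16, 8, 4, 6
    x = torch.randn(B, Lq, dtype=torch.float64)
    W_Q = torch.randn(Lq, N, P, dtype=torch.float64)
    W_UK = torch.randn(Lkv, N, P, dtype=torch.float64)
    # Unabsorbed: up-project q, then contract with W_UK.
    q_nope = torch.matmul(x, W_Q.flatten(start_dim=1)).view(B, N, P)
    ref = torch.einsum("bnp,lnp->bnl", q_nope, W_UK)
    # Absorbed: precompute W_Q_UK once, then a single matmul per step.
    W_Q_UK = torch.einsum("qnd,lnd->qnl", W_Q, W_UK).flatten(start_dim=1)
    out = torch.matmul(x, W_Q_UK).view(B, N, Lkv)
    assert torch.allclose(ref, out)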
|
||||
|
||||
def process_weights_after_loading(self):
|
||||
kv_b_proj_weight = self.kv_b_proj.weight.T
|
||||
assert kv_b_proj_weight.shape == (
|
||||
self.kv_lora_rank,
|
||||
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
|
||||
f"{kv_b_proj_weight.shape=}, "
|
||||
f"{self.kv_lora_rank=}, "
|
||||
f"{self.num_heads=}, "
|
||||
f"{self.qk_nope_head_dim=}, "
|
||||
f"{self.v_head_dim=}")
|
||||
kv_b_proj_weight = kv_b_proj_weight.view(
|
||||
self.kv_lora_rank,
|
||||
self.num_heads,
|
||||
self.qk_nope_head_dim + self.v_head_dim,
|
||||
)
|
||||
|
||||
W_UK, W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
q_proj = self.q_proj.weight.T\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# can be W_Q or W_UQ depending on q_lora_rank, the former if
|
||||
# q_lora_rank is None, the latter otherwise. From the Attention backend
|
||||
# perspective though we call these both W_Q and rely on the layer
|
||||
# to pass in the correct matrix
|
||||
W_Q = q_proj[..., :self.qk_nope_head_dim]
|
||||
self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
|
||||
#
|
||||
# Perform matrix-absorption following
|
||||
# https://github.com/flashinfer-ai/flashinfer/pull/551
|
||||
# for decode, as a result we end up with absorbed weights for decode
|
||||
# and another copy of raw weights for prefill.
|
||||
#
|
||||
self.W_UK, self.W_UV = kv_b_proj_weight.split(
|
||||
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
|
||||
# depending on q_lora_rank, the former if q_lora_rank is None, the
|
||||
# latter otherwise
|
||||
# basically, if q_lora_rank is None we are absorbing into q_proj
|
||||
# instead of UQ
|
||||
self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
|
||||
.flatten(start_dim=1).contiguous()
|
||||
|
||||
W_O = self.o_proj.weight\
|
||||
.view(-1, self.num_heads, self.v_head_dim)
|
||||
self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
|
||||
.flatten(start_dim=0, end_dim=1).contiguous()
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.o_proj_absorbed = RowParallelLinear(
|
||||
self.W_UV_O.shape[0] * tp_size,
|
||||
self.W_UV_O.shape[1],
|
||||
bias=False,
|
||||
# TODO(lucas) figure out how to properly forward quant_method
|
||||
#quant_config=self.o_proj.quant_method,
|
||||
)
|
||||
|
||||
self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
|
||||
else:
|
||||
self.W_UV = W_UV
|
||||
self.W_UK = W_UK
|
||||
self.W_Q = W_Q.flatten(start_dim=1)
|
||||
|
||||
@abstractmethod
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
attn_metadata: MLAMetadataCommon,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: MLAMetadataCommon,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
hidden_states_or_q_c: torch.Tensor, # query in unified attn
|
||||
k_c_normed: torch.Tensor, # key in unified attn
|
||||
k_pe: torch.Tensor, # value in unified attn
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: MLAMetadataCommon,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if output is not None:
|
||||
raise NotImplementedError(
|
||||
"output is not yet supported for MLAImplBase")
|
||||
|
||||
is_decode = attn_metadata.decode_metadata is not None
|
||||
is_prefill = attn_metadata.prefill_metadata is not None
|
||||
|
||||
if (is_decode and is_prefill):
|
||||
raise NotImplementedError(
|
||||
"chunked prefill is not supported for MLAImplBase")
|
||||
|
||||
# Restore head dim (for rotary embedding)
|
||||
k_pe = k_pe.unsqueeze(1)
|
||||
|
||||
if is_decode:
|
||||
q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
|
||||
q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
|
||||
.view(-1, self.num_heads, self.qk_rope_head_dim)
|
||||
q_pe, k_pe = \
|
||||
self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
|
||||
else:
|
||||
assert is_prefill
|
||||
q = self.q_proj(hidden_states_or_q_c)[0]\
|
||||
.view(-1, self.num_heads, self.qk_head_dim)
|
||||
|
||||
# TODO(lucas): there must be a nicer way to write this line
|
||||
q[..., self.qk_nope_head_dim:], k_pe = \
|
||||
self.rotary_emb(
|
||||
attn_metadata.input_positions,
|
||||
q[..., self.qk_nope_head_dim:], k_pe)
|
||||
|
||||
# write the latent and rope to kv cache
|
||||
if kv_cache.numel() > 0:
|
||||
ops.concat_and_cache_mla(
|
||||
k_c_normed,
|
||||
k_pe.squeeze(1),
|
||||
kv_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
kv_cache_dtype=self.kv_cache_dtype,
|
||||
scale=layer._k_scale,
|
||||
)
|
||||
|
||||
if attn_metadata.prefill_metadata is not None:
|
||||
return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata)
|
||||
|
||||
if attn_metadata.decode_metadata is not None:
|
||||
return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata)
|
||||
|
||||
# Optional common flash-attn based prefill
|
||||
def _forward_prefill_flash(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
k_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
seq_start_loc: torch.Tensor,
|
||||
max_prefill_seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
|
||||
kv_nope = self.kv_b_proj(k_c_normed)[0]\
|
||||
.view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim
|
||||
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
value=0)
|
||||
|
||||
attn_output = flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v_padded,
|
||||
cu_seqlens_q=seq_start_loc,
|
||||
cu_seqlens_k=seq_start_loc,
|
||||
max_seqlen_q=max_prefill_seq_len,
|
||||
max_seqlen_k=max_prefill_seq_len,
|
||||
softmax_scale=self.scale,
|
||||
causal=True,
|
||||
)
|
||||
attn_output = attn_output\
|
||||
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
|
||||
.reshape(-1, self.num_heads * v.shape[-1])
|
||||
|
||||
return self.o_proj(attn_output)[0]
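

# Standalone sketch (not part of the diff): why zero-padding v up to the qk
# head dim in _forward_prefill_flash is safe. The attention weights mix tokens,
# not channels, so padded channels stay exactly zero and slicing them off after
# the kernel recovers the unpadded result. Shapes below are illustrative.
def _v_padding_equivalence_sketch():
    L, H, D_QK, D_V = 5, 2, 8, 4
    q = torch.randn(L, H, D_QK, dtype=torch.float64)
    k = torch.randn(L, H, D_QK, dtype=torch.float64)
    v = torch.randn(L, H, D_V, dtype=torch.float64)
    w = torch.softmax(torch.einsum("qhd,khd->hqk", q, k) / D_QK**0.5, dim=-1)
    out_ref = torch.einsum("hqk,khd->qhd", w, v)
    v_pad = torch.nn.functional.pad(v, [0, D_QK - D_V], value=0)
    out_pad = torch.einsum("hqk,khd->qhd", w, v_pad)[..., :D_V]
    assert torch.allclose(out_ref, out_pad)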
|
vllm/attention/backends/triton_mla.py (new file, 742 lines)
@@ -0,0 +1,742 @@
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from vllm.multimodal import MultiModalPlaceholderMap
|
||||
|
||||
try:
|
||||
from flashinfer import BatchDecodeMlaWithPagedKVCacheWrapper
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
|
||||
except ImportError:
|
||||
BatchDecodeMlaWithPagedKVCacheWrapper = None
|
||||
FLASHINFER_WORKSPACE_BUFFER_SIZE = 0
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionState, AttentionType)
|
||||
from vllm.attention.backends.mla.utils import MLACommonImpl, MLAMetadataCommon
|
||||
from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
|
||||
compute_slot_mapping_start_idx,
|
||||
is_block_tables_empty)
|
||||
from vllm.attention.ops.paged_attn import PagedAttention
|
||||
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
|
||||
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
|
||||
ModelInputForGPUWithSamplingMetadata)
|
||||
|
||||
|
||||
class TritonMLABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TRITON_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["TritonMLAImpl"]:
|
||||
return TritonMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return TritonMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["TritonMLAMetadataBuilder"]:
|
||||
return TritonMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["TritonMLAState"]:
|
||||
return TritonMLAState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
kv_lora_rank: int, # passed via head_size
|
||||
) -> Tuple[int, ...]:
|
||||
# TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank
|
||||
k_pe_size = kv_lora_rank // 8
|
||||
return (num_blocks, block_size, kv_lora_rank + k_pe_size)
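# Illustrative numbers (assumed, not from the diff): with the hardcoded 1/8
# ratio above, a DeepSeek-V2-style kv_lora_rank of 512 gives k_pe_size = 64,
# so each cache slot stores a 512 + 64 = 576 wide vector -- the same 576 that
# appears as D_QK in tests/kernels/test_triton_decode_attention.py.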
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
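# Note: 512 here is the kv_lora_rank (passed in via head_size, see
# get_kv_cache_shape above), not a conventional attention head dim.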
|
||||
return [512]
|
||||
|
||||
|
||||
class TritonMLAState(AttentionState):
|
||||
|
||||
def __init__(self, runner):
|
||||
self.runner = runner
|
||||
self._is_graph_capturing = False
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int,
|
||||
positions: Optional[torch.Tensor]):
|
||||
self._is_graph_capturing = True
|
||||
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.long,
|
||||
device=self.runner.device)
|
||||
self._graph_seq_lens = torch.ones(max_batch_size,
|
||||
dtype=torch.int32,
|
||||
device=self.runner.device)
|
||||
self._graph_block_tables = torch.from_numpy(
|
||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||
|
||||
assert positions is not None
|
||||
self._positions = positions
|
||||
|
||||
yield
|
||||
|
||||
self._is_graph_capturing = False
|
||||
del self._graph_slot_mapping
|
||||
del self._graph_seq_lens
|
||||
del self._graph_block_tables
|
||||
|
||||
def graph_clone(self, batch_size: int):
|
||||
assert self._is_graph_capturing
|
||||
return self.__class__(self.runner)
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self, batch_size: int, is_encoder_decoder_model: bool = False):
|
||||
assert self._is_graph_capturing
|
||||
|
||||
attn_metadata = self.runner.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=self._graph_slot_mapping[:batch_size],
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=self._graph_seq_lens[:batch_size],
|
||||
max_query_len=1,
|
||||
max_decode_query_len=1,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.runner.max_seq_len_to_capture,
|
||||
query_start_loc=None,
|
||||
seq_start_loc=None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=self._graph_block_tables[:batch_size],
|
||||
use_cuda_graph=True,
|
||||
input_positions=self._positions[:batch_size],
|
||||
head_dim=self.runner.model_config.get_head_size())
|
||||
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"TritonMLAState does not support encoder/decoder yet")
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def get_graph_input_buffers(self,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_buffers = {
|
||||
"slot_mapping": attn_metadata.slot_mapping,
|
||||
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
|
||||
"block_tables": attn_metadata.decode_metadata.block_tables,
|
||||
}
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"TritonMLAState does not support encoder/decoder yet")
|
||||
|
||||
return input_buffers
|
||||
|
||||
def prepare_graph_input_buffers(self,
|
||||
input_buffers,
|
||||
attn_metadata,
|
||||
is_encoder_decoder_model: bool = False):
|
||||
input_buffers["seq_lens_tensor"].copy_(
|
||||
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
|
||||
input_buffers["block_tables"].copy_(
|
||||
attn_metadata.decode_metadata.block_tables, non_blocking=True)
|
||||
if is_encoder_decoder_model:
|
||||
raise NotImplementedError(
|
||||
"TritonMLAState does not support encoder/decoder yet")
|
||||
|
||||
def begin_forward(self, model_input):
|
||||
return
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TritonMLAMetadata(MLAMetadataCommon):
|
||||
"""Metadata for TritonMLAMetadata.
|
||||
|
||||
NOTE: Any python object stored here is not updated when it is
|
||||
cuda-graph replayed. If you have values that need to be changed
|
||||
dynamically, they should be stored in a tensor. The tensor has to be
|
||||
updated from `CUDAGraphRunner.forward` API.
|
||||
"""
|
||||
# (batch_size,). The sequence length per sequence. Sequence length means
|
||||
# the computed tokens + new tokens. None if it is decoding.
|
||||
seq_lens: Optional[List[int]]
|
||||
# seq_lens stored as a tensor.
|
||||
seq_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# NOTE(sang): Definition of context_len, query_len, and seq_len.
|
||||
# |---------- N-1 iteration --------|
|
||||
# |---------------- N iteration ---------------------|
|
||||
# |- tokenA -|......................|-- newTokens ---|
|
||||
# |---------- context_len ----------|
|
||||
# |-------------------- seq_len ---------------------|
|
||||
# |-- query_len ---|
|
||||
|
||||
# Maximum sequence length among prefill batch. 0 if there are decoding
|
||||
# requests only.
|
||||
max_prefill_seq_len: int
|
||||
# Maximum sequence length among decode batch. 0 if there are prefill
|
||||
# requests only.
|
||||
max_decode_seq_len: int
|
||||
# (batch_size,) A tensor of context lengths (tokens that are computed
|
||||
# so far).
|
||||
context_lens_tensor: Optional[torch.Tensor]
|
||||
|
||||
# (batch_size, max_blocks_per_seq).
|
||||
# Block addresses per sequence. (Seq id -> list of physical block)
|
||||
# E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
|
||||
# in the kv cache. Each block can contain up to block_size tokens.
|
||||
# 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph
|
||||
# captured.
|
||||
block_tables: Optional[torch.Tensor]
|
||||
|
||||
# Whether or not cuda graph is enabled.
|
||||
# Cuda-graph is currently enabled for decoding only.
|
||||
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
|
||||
|
||||
use_cuda_graph: bool
|
||||
|
||||
# Maximum query length in the batch.
|
||||
max_query_len: Optional[int] = None
|
||||
|
||||
# Max number of query tokens among requests in the batch.
|
||||
max_decode_query_len: Optional[int] = None
|
||||
|
||||
# (batch_size + 1,). The cumulative subquery lengths of the sequences in
|
||||
# the batch, used to index into subquery. E.g., if the subquery length
|
||||
# is [4, 6], it is [0, 4, 10].
|
||||
query_start_loc: Optional[torch.Tensor] = None
|
||||
# (batch_size + 1,). The cumulative sequence lengths of the sequences in
|
||||
# the batch, used to index into sequence. E.g., if the sequence length is
|
||||
# [4, 6], it is [0, 4, 10].
|
||||
seq_start_loc: Optional[torch.Tensor] = None
|
||||
|
||||
_cached_prefill_metadata: Optional["TritonMLAMetadata"] = None
|
||||
_cached_decode_metadata: Optional["TritonMLAMetadata"] = None
|
||||
|
||||
num_prefill_tokens: int
|
||||
|
||||
num_kv_splits: int = 4 # TODO(lucas) add heuristic
|
||||
attn_logits: Optional[torch.Tensor] = None
|
||||
req_idx: Optional[torch.Tensor] = None
|
||||
|
||||
# The dimension of the attention heads
|
||||
head_dim: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
supported_head_sizes = TritonMLABackend.get_supported_head_sizes()
|
||||
if self.head_dim is not None and self.head_dim \
|
||||
not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Only {supported_head_sizes} are supported for head_dim,",
|
||||
f"received {self.head_dim}.")
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["TritonMLAMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
if self._cached_prefill_metadata is not None:
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
assert ((self.seq_lens is not None)
|
||||
or (self.encoder_seq_lens is not None))
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
query_start_loc = (None if self.query_start_loc is None else
|
||||
self.query_start_loc[:self.num_prefills + 1])
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[:self.num_prefill_tokens])
|
||||
seq_lens = (None if self.seq_lens is None else
|
||||
self.seq_lens[:self.num_prefills])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[:self.num_prefills])
|
||||
seq_start_loc = (None if self.seq_start_loc is None else
|
||||
self.seq_start_loc[:self.num_prefills + 1])
|
||||
context_lens_tensor = (None if self.context_lens_tensor is None else
|
||||
self.context_lens_tensor[:self.num_prefills])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[:self.num_prefills])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[:self.num_prefill_tokens])
|
||||
|
||||
self._cached_prefill_metadata = TritonMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=self.
|
||||
multi_modal_placeholder_index_maps,
|
||||
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
|
||||
seq_lens=seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=self.max_prefill_seq_len,
|
||||
max_decode_query_len=0,
|
||||
max_decode_seq_len=0,
|
||||
query_start_loc=query_start_loc,
|
||||
seq_start_loc=seq_start_loc,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=False,
|
||||
input_positions=input_positions,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_prefill_metadata
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["TritonMLAMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
if self._cached_decode_metadata is not None:
|
||||
return self._cached_decode_metadata
|
||||
assert ((self.seq_lens_tensor is not None)
|
||||
or (self.encoder_seq_lens_tensor is not None))
|
||||
|
||||
# Compute some attn_metadata fields which default to None
|
||||
slot_mapping = (None if self.slot_mapping is None else
|
||||
self.slot_mapping[self.num_prefill_tokens:])
|
||||
seq_lens_tensor = (None if self.seq_lens_tensor is None else
|
||||
self.seq_lens_tensor[self.num_prefills:])
|
||||
block_tables = (None if self.block_tables is None else
|
||||
self.block_tables[self.num_prefills:])
|
||||
input_positions = (None if self.input_positions is None else
|
||||
self.input_positions[self.num_prefill_tokens:])
|
||||
|
||||
self._cached_decode_metadata = TritonMLAMetadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=self.num_decode_tokens,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens=None,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_decode_query_len=self.max_decode_query_len,
|
||||
max_query_len=self.max_query_len,
|
||||
max_prefill_seq_len=0,
|
||||
max_decode_seq_len=self.max_decode_seq_len,
|
||||
# Batch may be composed of prefill|decodes, adjust query start
|
||||
# indices to refer to the start of decodes. E.g.
|
||||
# in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
|
||||
query_start_loc=(self.query_start_loc[self.num_prefills:] -
|
||||
self.query_start_loc[self.num_prefills])
|
||||
if self.query_start_loc is not None else None,
|
||||
seq_start_loc=self.seq_start_loc[self.num_prefills:]
|
||||
if self.seq_start_loc is not None else None,
|
||||
context_lens_tensor=None,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=self.use_cuda_graph,
|
||||
input_positions=input_positions,
|
||||
head_dim=self.head_dim)
|
||||
return self._cached_decode_metadata
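
# Worked example for the query_start_loc re-basing in decode_metadata above
# (illustrative numbers, not from the diff): with 3 prefills of query lengths
# [4, 3, 2] followed by 6 single-token decodes, query_start_loc is
# [0, 4, 7, 9, 10, 11, 12, 13, 14, 15]; slicing from num_prefills == 3 and
# subtracting query_start_loc[3] == 9 yields [0, 1, 2, 3, 4, 5, 6], i.e. the
# cumulative query lengths of just the decode requests.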
|
||||
|
||||
def advance_step(self,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
sampled_token_ids: Optional[torch.Tensor],
|
||||
block_size: int,
|
||||
num_seqs: int,
|
||||
num_queries: int,
|
||||
turn_prefills_into_decodes: bool = False):
|
||||
"""
|
||||
Update metadata in-place to advance one decode step.
|
||||
"""
|
||||
# When using cudagraph, the num_seqs is padded to the next captured
|
||||
# batch size, but num_queries tracks the actual number of requests in
|
||||
# the batch. For --enforce-eager mode, num_seqs == num_queries
|
||||
if num_seqs != num_queries:
|
||||
assert num_seqs > num_queries
|
||||
assert self.use_cuda_graph
|
||||
|
||||
if turn_prefills_into_decodes:
|
||||
# When Multi-Step is enabled with Chunked-Prefill, prefills and
|
||||
# decodes are scheduled together. In the first step, all the
|
||||
# prefills turn into decodes. This update reflects that
|
||||
# conversion.
|
||||
assert self.num_decode_tokens + self.num_prefills == num_seqs
|
||||
self.num_decode_tokens += self.num_prefills
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.max_prefill_seq_len = 0
|
||||
self.max_query_len = 1
|
||||
|
||||
self.slot_mapping = self.slot_mapping[:num_seqs]
|
||||
else:
|
||||
assert self.seq_lens is not None
|
||||
assert self.max_decode_seq_len == max(self.seq_lens)
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.num_decode_tokens == num_seqs
|
||||
assert self.slot_mapping.shape == (num_seqs, )
|
||||
|
||||
assert self.seq_lens is not None
|
||||
assert len(self.seq_lens) == num_seqs
|
||||
assert self.seq_lens_tensor is not None
|
||||
assert self.seq_lens_tensor.shape == (num_seqs, )
|
||||
assert self.max_query_len == 1
|
||||
assert self.max_prefill_seq_len == 0
|
||||
|
||||
assert self.query_start_loc is not None
|
||||
assert self.query_start_loc.shape == (num_queries + 1, )
|
||||
assert self.seq_start_loc is not None
|
||||
assert self.seq_start_loc.shape == (num_seqs + 1, )
|
||||
|
||||
assert self.context_lens_tensor is not None
|
||||
assert self.context_lens_tensor.shape == (num_queries, )
|
||||
|
||||
assert self.block_tables is not None
|
||||
assert self.block_tables.shape[0] == num_seqs
|
||||
|
||||
# Update query lengths. Note that we update only queries and not seqs,
|
||||
# since tensors may be padded due to captured cuda graph batch size
|
||||
for i in range(num_queries):
|
||||
self.seq_lens[i] += 1
|
||||
self.max_decode_seq_len = max(self.seq_lens)
|
||||
|
||||
ops.advance_step_flashattn(num_seqs=num_seqs,
|
||||
num_queries=num_queries,
|
||||
block_size=block_size,
|
||||
input_tokens=model_input.input_tokens,
|
||||
sampled_token_ids=sampled_token_ids,
|
||||
input_positions=model_input.input_positions,
|
||||
seq_lens=self.seq_lens_tensor,
|
||||
slot_mapping=self.slot_mapping,
|
||||
block_tables=self.block_tables)
|
||||
|
||||
|
||||
class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
|
||||
|
||||
def __init__(self, input_builder: "ModelInputForGPUBuilder"):
|
||||
self.input_builder = input_builder
|
||||
self.runner = input_builder.runner
|
||||
self.sliding_window = input_builder.sliding_window
|
||||
self.block_size = input_builder.block_size
|
||||
|
||||
def prepare(self):
|
||||
self.slot_mapping: List[int] = []
|
||||
self.prefill_seq_lens: List[int] = []
|
||||
self.context_lens: List[int] = []
|
||||
self.block_tables: List[List[int]] = []
|
||||
self.curr_seq_lens: List[int] = []
|
||||
self.input_positions: List[int] = []
|
||||
self.multimodal_placeholder_maps: Dict[
|
||||
str,
|
||||
MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
|
||||
self.num_prefills = 0
|
||||
self.num_prefill_tokens = 0
|
||||
self.num_decode_tokens = 0
|
||||
self.has_prefix_cache_hit = False
|
||||
|
||||
def _add_seq_group(
|
||||
self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
|
||||
chunked_prefill_enabled: bool, prefix_cache_hit: bool):
|
||||
"""Add a sequence group to the metadata. Specifically update/append
|
||||
1. context length.
|
||||
2. block table.
|
||||
3. slot mapping.
|
||||
"""
|
||||
is_prompt = inter_data.is_prompt
|
||||
block_tables = inter_data.block_tables
|
||||
|
||||
for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
|
||||
curr_sliding_window_block, input_positions) in zip(
|
||||
inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
|
||||
inter_data.orig_seq_lens, inter_data.seq_lens,
|
||||
inter_data.query_lens, inter_data.context_lens,
|
||||
inter_data.curr_sliding_window_blocks,
|
||||
inter_data.input_positions):
|
||||
self.input_positions.extend(input_positions)
|
||||
self.context_lens.append(context_len)
|
||||
if is_prompt:
|
||||
mm_maps = inter_data.multi_modal_placeholder_maps
|
||||
if mm_maps:
|
||||
for modality, placeholders in mm_maps.items():
|
||||
self.multimodal_placeholder_maps[modality].extend(
|
||||
placeholders)
|
||||
|
||||
self.num_prefills += 1
|
||||
self.num_prefill_tokens += token_len
|
||||
self.prefill_seq_lens.append(seq_len)
|
||||
else:
|
||||
self.num_decode_tokens += query_len
|
||||
self.curr_seq_lens.append(curr_seq_len)
|
||||
|
||||
# Compute block table.
|
||||
# TODO(sang): Combine chunked prefill and prefix caching by
|
||||
# only allowing multiple of block_size chunk size.
|
||||
# NOTE: This only works for oooooooxxx style attention.
|
||||
block_table = []
|
||||
if prefix_cache_hit:
|
||||
# NOTE(woosuk): For flash-attn, the block table should
|
||||
# include the entries for the incoming prefill tokens.
|
||||
block_table = block_tables[seq_id]
|
||||
elif ((chunked_prefill_enabled or not is_prompt)
|
||||
and block_tables is not None):
|
||||
if curr_sliding_window_block == 0:
|
||||
block_table = block_tables[seq_id]
|
||||
else:
|
||||
block_table = block_tables[seq_id][
|
||||
-curr_sliding_window_block:]
|
||||
self.block_tables.append(block_table)
|
||||
|
||||
# Compute slot mapping.
|
||||
is_profile_run = is_block_tables_empty(block_tables)
|
||||
start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
|
||||
context_len,
|
||||
self.sliding_window)
|
||||
compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
|
||||
seq_len, context_len, start_idx,
|
||||
self.block_size, inter_data.block_tables)
|
||||
|
||||
def _get_graph_runner_block_tables(
|
||||
self, num_seqs: int,
|
||||
block_tables: List[List[int]]) -> torch.Tensor:
|
||||
# The shape of graph_block_tables is
|
||||
# [max batch size, max context len // block size].
|
||||
max_batch_size, max_blocks = self.runner.graph_block_tables.shape
|
||||
assert max_batch_size >= num_seqs
|
||||
|
||||
graph_block_tables = self.runner.graph_block_tables[:num_seqs]
|
||||
for i, block_table in enumerate(block_tables):
|
||||
if block_table:
|
||||
num_blocks = len(block_table)
|
||||
if num_blocks <= max_blocks:
|
||||
graph_block_tables[i, :num_blocks] = block_table
|
||||
else:
|
||||
# It may be possible to have more blocks allocated due
|
||||
# to lookahead slots of multi-step, however, they are
|
||||
# not used anyway, so can be safely ignored.
|
||||
graph_block_tables[
|
||||
i, :max_blocks] = block_table[:max_blocks]
|
||||
|
||||
return torch.from_numpy(graph_block_tables).to(
|
||||
device=self.runner.device, non_blocking=True)
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int):
|
||||
"""Build attention metadata with on-device tensors.
|
||||
|
||||
Args:
|
||||
seq_lens: The maybe padded sequence lengths of the input sequences.
|
||||
query_lens: The query lengths of the input sequences.
|
||||
cuda_graph_pad_size: The padding size for cuda graph.
|
||||
-1 if cuda graph is not used.
|
||||
batch_size: The maybe padded batch size.
|
||||
"""
|
||||
prefix_cache_hit = any([
|
||||
inter_data.prefix_cache_hit
|
||||
for inter_data in self.input_builder.inter_data_list
|
||||
])
|
||||
for inter_data in self.input_builder.inter_data_list:
|
||||
self._add_seq_group(inter_data,
|
||||
self.input_builder.chunked_prefill_enabled,
|
||||
prefix_cache_hit)
|
||||
|
||||
device = self.runner.device
|
||||
use_captured_graph = cuda_graph_pad_size != -1
|
||||
|
||||
max_query_len = max(query_lens)
|
||||
decode_query_lens = query_lens[self.num_prefills:]
|
||||
if len(decode_query_lens) > 0:
|
||||
max_decode_query_len = max(decode_query_lens)
|
||||
else:
|
||||
max_decode_query_len = 1
|
||||
max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
|
||||
max_decode_seq_len = max(self.curr_seq_lens, default=0)
|
||||
num_decode_tokens = self.num_decode_tokens
|
||||
query_start_loc = list(accumulate(query_lens, initial=0))
|
||||
seq_start_loc = list(accumulate(seq_lens, initial=0))
|
||||
|
||||
num_seqs = len(seq_lens)
|
||||
if use_captured_graph:
|
||||
self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
|
||||
self.block_tables.extend([] * cuda_graph_pad_size)
|
||||
num_decode_tokens = batch_size - self.num_prefill_tokens
|
||||
block_tables = self._get_graph_runner_block_tables(
|
||||
num_seqs, self.block_tables)
|
||||
else:
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int,
|
||||
device=device,
|
||||
)
|
||||
assert max_query_len > 0, ("query_lens: {}".format(query_lens))
|
||||
|
||||
assert device is not None
|
||||
context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
|
||||
device, self.runner.pin_memory)
|
||||
seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
|
||||
self.runner.pin_memory)
|
||||
input_positions = async_tensor_h2d(self.input_positions, torch.long,
|
||||
device, self.runner.pin_memory)
|
||||
slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
|
||||
device, self.runner.pin_memory)
|
||||
query_start_loc_tensor = async_tensor_h2d(query_start_loc, torch.int32,
|
||||
device,
|
||||
self.runner.pin_memory)
|
||||
seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
|
||||
device, self.runner.pin_memory)
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
self.multimodal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
num_kv_splits = 8
|
||||
|
||||
return TritonMLAMetadata(
|
||||
num_prefills=self.num_prefills,
|
||||
slot_mapping=slot_mapping_tensor,
|
||||
num_prefill_tokens=self.num_prefill_tokens,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
seq_lens=seq_lens,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=True,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_decode_query_len=max_decode_query_len,
|
||||
max_prefill_seq_len=max_prefill_seq_len,
|
||||
max_decode_seq_len=max_decode_seq_len,
|
||||
query_start_loc=query_start_loc_tensor,
|
||||
seq_start_loc=seq_start_loc_tensor,
|
||||
context_lens_tensor=context_lens_tensor,
|
||||
block_tables=block_tables,
|
||||
use_cuda_graph=use_captured_graph,
|
||||
input_positions=input_positions,
|
||||
num_kv_splits=num_kv_splits,
|
||||
head_dim=self.runner.model_config.get_head_size(),
|
||||
)
|
||||
|
||||
|
||||
class TritonMLAImpl(MLACommonImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
# MLA Specific Arguments
|
||||
**kwargs) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
**kwargs)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"TritonMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"TritonMLAImpl")
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
attn_metadata: TritonMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
return self._forward_prefill_flash(q, kv_c_normed, k_pe,
|
||||
attn_metadata.seq_start_loc,
|
||||
attn_metadata.max_prefill_seq_len)
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: TritonMLAMetadata,
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
if self.kv_cache_dtype.startswith("fp8"):
|
||||
raise NotImplementedError("FP8 Triton MLA not yet supported")
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
device=q.device)
|
||||
|
||||
# TODO(lucas) Allocate ahead of time
|
||||
attn_logits = torch.empty(
|
||||
(
|
||||
B,
|
||||
self.num_heads,
|
||||
attn_metadata.num_kv_splits,
|
||||
# NOTE(lucas): it is unclear why the +1 is needed, but sglang has it so we
|
||||
# just mirror that
|
||||
self.kv_lora_rank + 1,
|
||||
),
|
||||
dtype=torch.float32,
|
||||
device=q.device,
|
||||
)
|
||||
|
||||
# Add a head dim of 1
|
||||
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
|
||||
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
|
||||
|
||||
# Run MQA
|
||||
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor, attn_logits,
|
||||
attn_metadata.num_kv_splits, self.scale,
|
||||
PAGE_SIZE)
|
||||
|
||||
return self._v_up_proj_and_o_proj(o)
|
@@ -2,7 +2,7 @@
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from itertools import accumulate
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union, Optional)
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -288,8 +288,11 @@ class CommonAttentionState(AttentionState):
|
||||
self._is_graph_capturing = False
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
def graph_capture(self, max_batch_size: int, positions: Optional[torch.Tensor]):
|
||||
assert positions is None
|
||||
|
||||
self._is_graph_capturing = True
|
||||
|
||||
self._graph_slot_mapping = torch.full((max_batch_size, ),
|
||||
PAD_SLOT_ID,
|
||||
dtype=torch.long,
|
||||
@@ -299,7 +302,9 @@ class CommonAttentionState(AttentionState):
|
||||
device=self.runner.device)
|
||||
self._graph_block_tables = torch.from_numpy(
|
||||
self.runner.graph_block_tables).to(device=self.runner.device)
|
||||
|
||||
yield
|
||||
|
||||
self._is_graph_capturing = False
|
||||
del self._graph_slot_mapping
|
||||
del self._graph_seq_lens
|
||||
|
@@ -41,8 +41,10 @@ class Attention(nn.Module):
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if per_layer_sliding_window is not None:
|
||||
@@ -101,13 +103,18 @@ class Attention(nn.Module):
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
attn_backend = get_attn_backend(head_size, dtype, kv_cache_dtype,
|
||||
block_size, is_attention_free,
|
||||
blocksparse_params is not None)
|
||||
attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
blocksparse_params is not None,
|
||||
use_mla=use_mla)
|
||||
impl_cls = attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type)
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
**extra_impl_args)
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
@ -193,6 +200,10 @@ class Attention(nn.Module):
|
||||
s += f", backend={self.impl.__class__.__name__}"
|
||||
return s
|
||||
|
||||
def process_weights_after_loading(self):
|
||||
if hasattr(self.impl, "process_weights_after_loading"):
|
||||
self.impl.process_weights_after_loading()
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
"""Multi-headed attention without any cache, used for ViT."""
|
||||
|
667 vllm/attention/ops/triton_decode_attention.py Normal file
@ -0,0 +1,667 @@
# Adapted from
# https://github.com/sgl-project/sglang/blob/9f635ea50de920aa507f486daafba26a5b837574/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
# which was originally adapted from
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py
# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py

# Changes:
# - Add support for page size >= 1.

# Copyright 2025 vLLM Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Memory-efficient attention for decoding.
It supports page size >= 1.
"""

import logging

import triton
import triton.language as tl

from vllm.platforms import current_platform

is_hip_ = current_platform.is_rocm()

logger = logging.getLogger(__name__)

# TODO: Remove this when triton>=3.2.0. This issue will not affect performance
# and accuracy.
logger.warning(
    "The following error message 'operation scheduled before its operands' "
    "can be ignored.")


@triton.jit
def tanh(x):
    # Tanh is just a scaled sigmoid
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def _fwd_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)
    split_kv_id = tl.program_id(2)

    cur_kv_head = cur_head // kv_group_num

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = cur_batch

    off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
    q = tl.load(Q + off_q, mask=mask_d, other=0.0)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                              cur_batch_seq_len)

    e_max = -float("inf")
    e_sum = 0.0
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[:, None] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[None, :])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]),
                other=0.0,
            )
            qk = tl.sum(q[None, :] * k, 1)
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            n_e_max = tl.maximum(tl.max(qk, 0), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max)
            acc *= re_scale
            acc += tl.sum(p[:, None] * v, 0)

            e_sum = e_sum * re_scale + tl.sum(p, 0)
            e_max = n_e_max

        offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                      split_kv_id * stride_mid_os + offs_dv)

        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum,
            mask=(mask_dv),
        )

        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                        split_kv_id * stride_mid_os + Lv)

        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
        )


def _decode_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    BLOCK = 64
    NUM_KV_SPLITS = num_kv_splits
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    batch, head_num = q.shape[0], q.shape[1]

    grid = (batch, head_num, NUM_KV_SPLITS)
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    num_warps = 4 if kv_group_num == 1 else 2

    BLOCK_DMODEL = triton.next_power_of_2(Lk)
    BLOCK_DV = triton.next_power_of_2(Lv)

    _fwd_kernel_stage1[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-2),
        k_buffer.stride(-1),
        v_buffer.stride(-2),
        v_buffer.stride(-1),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DV=BLOCK_DV,
        BLOCK_N=BLOCK,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=num_warps,
        num_stages=2,
        Lk=Lk,
        Lv=Lv,
    )


@triton.jit
def _fwd_grouped_kernel_stage1(
    Q,
    K_Buffer,
    V_Buffer,
    sm_scale,
    Req_to_tokens,
    B_Seqlen,
    Att_Out,
    stride_req_to_tokens_b,
    stride_qbs,
    stride_qh,
    stride_buf_kbs,
    stride_buf_kh,
    stride_buf_vbs,
    stride_buf_vh,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    kv_group_num: tl.constexpr,
    q_head_num: tl.constexpr,
    BLOCK_DMODEL: tl.constexpr,
    BLOCK_DPE: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_H: tl.constexpr,
    NUM_KV_SPLITS: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    logit_cap: tl.constexpr,
    Lk: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head_id = tl.program_id(1)
    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
    split_kv_id = tl.program_id(2)

    if kv_group_num > BLOCK_H:
        VALID_BLOCK_H: tl.constexpr = BLOCK_H
    else:
        VALID_BLOCK_H: tl.constexpr = kv_group_num
    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
    mask_h = mask_h & (cur_head < q_head_num)

    offs_d = tl.arange(0, BLOCK_DMODEL)
    offs_dv = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lk
    mask_dv = offs_dv < Lv
    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
    cur_batch_req_idx = cur_batch

    offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[
        None, :]
    q = tl.load(Q + offs_q,
                mask=(mask_h[:, None]) & (mask_d[None, :]),
                other=0.0)

    if BLOCK_DPE > 0:
        offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
        mask_dpe = offs_dpe < Lk
        off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh +
                   offs_dpe[None, :])
        qpe = tl.load(Q + off_qpe,
                      mask=(mask_h[:, None]) & (mask_dpe[None, :]),
                      other=0.0)

    kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
    split_kv_start = kv_len_per_split * split_kv_id
    split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                              cur_batch_seq_len)

    e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([BLOCK_H], dtype=tl.float32)
    acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)

    if split_kv_end > split_kv_start:
        for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
            offs_n = start_n + tl.arange(0, BLOCK_N)
            kv_page_number = tl.load(
                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx +
                offs_n // PAGE_SIZE,
                mask=offs_n < split_kv_end,
                other=0,
            )
            kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
            offs_buf_k = (kv_loc[None, :] * stride_buf_kbs +
                          cur_kv_head * stride_buf_kh + offs_d[:, None])
            k = tl.load(
                K_Buffer + offs_buf_k,
                mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
                other=0.0,
            )
            qk = tl.dot(q, k.to(q.dtype))
            if BLOCK_DPE > 0:
                offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs +
                                cur_kv_head * stride_buf_kh +
                                offs_dpe[:, None])
                kpe = tl.load(
                    K_Buffer + offs_buf_kpe,
                    mask=(offs_n[None, :] < split_kv_end) &
                    (mask_dpe[:, None]),
                    other=0.0,
                )
                qk += tl.dot(qpe, kpe.to(qpe.dtype))
            qk *= sm_scale

            if logit_cap > 0:
                qk = logit_cap * tanh(qk / logit_cap)

            qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end),
                          qk, float("-inf"))

            offs_buf_v = (kv_loc[:, None] * stride_buf_vbs +
                          cur_kv_head * stride_buf_vh + offs_dv[None, :])
            v = tl.load(
                V_Buffer + offs_buf_v,
                mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
                other=0.0,
            )

            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            acc *= re_scale[:, None]
            acc += tl.dot(p.to(v.dtype), v)

            e_sum = e_sum * re_scale + tl.sum(p, 1)
            e_max = n_e_max

        offs_mid_o = (cur_batch * stride_mid_ob +
                      cur_head[:, None] * stride_mid_oh +
                      split_kv_id * stride_mid_os + offs_dv[None, :])

        tl.store(
            Att_Out + offs_mid_o,
            acc / e_sum[:, None],
            mask=(mask_h[:, None]) & (mask_dv[None, :]),
        )

        offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh +
                        split_kv_id * stride_mid_os + Lv)

        tl.store(
            Att_Out + offs_mid_o_1,
            e_max + tl.log(e_sum),
            mask=mask_h,
        )


def _decode_grouped_att_m_fwd(
    q,
    k_buffer,
    v_buffer,
    att_out,
    Req_to_tokens,
    B_Seqlen,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap,
):
    BLOCK = 32
    Lk = k_buffer.shape[-1]
    Lv = v_buffer.shape[-1]

    # [TODO] work around shmem limit on MI3xx
    if is_hip_ and Lk >= 576:
        BLOCK = 16

    if Lk == 576:
        BLOCK_DMODEL = 512
        BLOCK_DPE = 64
    elif Lk == 288:
        BLOCK_DMODEL = 256
        BLOCK_DPE = 32
    else:
        BLOCK_DMODEL = triton.next_power_of_2(Lk)
        BLOCK_DPE = 0
    BLOCK_DV = triton.next_power_of_2(Lv)

    batch, head_num = q.shape[0], q.shape[1]
    kv_group_num = q.shape[1] // k_buffer.shape[-2]

    BLOCK_H = 16
    NUM_KV_SPLITS = num_kv_splits
    grid = (
        batch,
        triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
        NUM_KV_SPLITS,
    )

    extra_kargs = {}
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {
            "waves_per_eu": 4,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }

    _fwd_grouped_kernel_stage1[grid](
        q,
        k_buffer,
        v_buffer,
        sm_scale,
        Req_to_tokens,
        B_Seqlen,
        att_out,
        Req_to_tokens.stride(0),
        q.stride(0),
        q.stride(1),
        k_buffer.stride(-2),
        k_buffer.stride(-1),
        v_buffer.stride(-2),
        v_buffer.stride(-1),
        att_out.stride(0),
        att_out.stride(1),
        att_out.stride(2),
        kv_group_num=kv_group_num,
        q_head_num=head_num,
        BLOCK_DMODEL=BLOCK_DMODEL,
        BLOCK_DPE=BLOCK_DPE,
        BLOCK_DV=BLOCK_DV,
        BLOCK_N=BLOCK,
        BLOCK_H=BLOCK_H,
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        PAGE_SIZE=page_size,
        logit_cap=logit_cap,
        num_warps=4,
        num_stages=2,
        Lk=Lk,
        Lv=Lv,
        **extra_kargs,
    )


@triton.jit
def _fwd_kernel_stage2(
    Mid_O,
    o,
    B_Seqlen,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,
    stride_obs,
    stride_oh,
    NUM_KV_SPLITS: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)

    offs_d = tl.arange(0, BLOCK_DV)
    mask_d = offs_d < Lv

    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([BLOCK_DV], dtype=tl.float32)

    offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d
    offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv

    for split_kv_id in range(0, NUM_KV_SPLITS):
        kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
        split_kv_start = kv_len_per_split * split_kv_id
        split_kv_end = tl.minimum(split_kv_start + kv_len_per_split,
                                  cur_batch_seq_len)

        if split_kv_end > split_kv_start:
            tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os,
                         mask=mask_d,
                         other=0.0)
            tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os)
            n_e_max = tl.maximum(tlogic, e_max)

            old_scale = tl.exp(e_max - n_e_max)
            acc *= old_scale
            exp_logic = tl.exp(tlogic - n_e_max)
            acc += exp_logic * tv

            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

    tl.store(
        o + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,
        mask=mask_d,
    )


def _decode_softmax_reducev_fwd(
    logits,
    q,
    o,
    v_buffer,
    b_seq_len,
    num_kv_splits,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    NUM_KV_SPLITS = num_kv_splits

    extra_kargs = {}
    if is_hip_:
        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
        # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
        extra_kargs = {
            "waves_per_eu": 4,
            "matrix_instr_nonkdim": 16,
            "kpack": 2
        }

    grid = (batch, head_num)
    _fwd_kernel_stage2[grid](
        logits,
        o,
        b_seq_len,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),
        o.stride(0),
        o.stride(1),
        NUM_KV_SPLITS=NUM_KV_SPLITS,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
        num_warps=4,
        num_stages=2,
        **extra_kargs,
    )


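This is the standard flash-decoding split: stage 1 writes, for every (batch, head, split), a normalized partial output of width Lv plus that split's log-sum-exp in the extra trailing slot, and _fwd_kernel_stage2 merges the splits with a running max. A small PyTorch reference of that merge, included only as an illustrative sketch (names and shapes are assumptions, not part of the kernel API; it also assumes every split covered at least one token):

# --- illustrative sketch (not part of the file above) ---
import torch

def merge_splits_reference(mid_o: torch.Tensor) -> torch.Tensor:
    # mid_o: [B, H, NUM_KV_SPLITS, Lv + 1]; the last slot of each split is the
    # log-sum-exp written by stage 1, the first Lv slots are that split's
    # normalized partial output.
    partial, lse = mid_o[..., :-1], mid_o[..., -1]
    # Weighting each split by a softmax over its log-sum-exp is algebraically
    # the same as the split-by-split running-max merge in _fwd_kernel_stage2.
    weights = torch.softmax(lse, dim=-1)
    return (weights.unsqueeze(-1) * partial).sum(dim=-2)

# Tiny self-check with made-up sizes:
mid_o = torch.randn(2, 4, 8, 65)
out = merge_splits_reference(mid_o)
assert out.shape == (2, 4, 64)
# --- end sketch ---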
def decode_attention_fwd_normal(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    _decode_att_m_fwd(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
                                num_kv_splits)


def decode_attention_fwd_grouped(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size,
    logit_cap=0.0,
):
    _decode_grouped_att_m_fwd(
        q,
        k_buffer,
        v_buffer,
        attn_logits,
        req_to_token,
        b_seq_len,
        num_kv_splits,
        sm_scale,
        page_size,
        logit_cap,
    )
    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len,
                                num_kv_splits)


def decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
    page_size=1,
    logit_cap=0.0,
):
    assert num_kv_splits == attn_logits.shape[2]
    kv_group_num = q.shape[1] // v_buffer.shape[-2]

    if kv_group_num == 1:
        # MHA
        decode_attention_fwd_normal(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
            page_size,
            logit_cap,
        )
    else:
        # GQA/MQA/MLA
        decode_attention_fwd_grouped(
            q,
            k_buffer,
            v_buffer,
            o,
            req_to_token,
            b_seq_len,
            attn_logits,
            num_kv_splits,
            sm_scale,
            page_size,
            logit_cap,
        )

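For orientation, a hedged usage sketch of decode_attention_fwd with the MLA-style paged cache layout that _forward_decode builds above ([num_pages, page_size, 1, kv_lora_rank + pe_dim]). Every size below is a made-up assumption, the scale is arbitrary, and the call is guarded because the kernels need Triton and a CUDA device with enough shared memory:

# --- illustrative sketch (not part of the file above) ---
import torch

from vllm.attention.ops.triton_decode_attention import decode_attention_fwd

if torch.cuda.is_available():
    B, H, KV_LORA, PE, PAGE, SEQ, SPLITS = 2, 16, 512, 64, 16, 64, 4
    pages_per_seq = SEQ // PAGE
    dev, dt = "cuda", torch.bfloat16

    q = torch.randn(B, H, KV_LORA + PE, device=dev, dtype=dt)
    # Paged latent cache with a single KV head: [num_pages, PAGE, 1, 576]
    kv_cache = torch.randn(B * pages_per_seq, PAGE, 1, KV_LORA + PE,
                           device=dev, dtype=dt)
    v_cache = kv_cache[..., :KV_LORA]          # "values" are the latent part
    o = torch.zeros(B, H, KV_LORA, device=dev, dtype=dt)
    attn_logits = torch.empty(B, H, SPLITS, KV_LORA + 1,
                              device=dev, dtype=torch.float32)
    # Each sequence owns pages_per_seq consecutive pages.
    block_tables = torch.arange(B * pages_per_seq, device=dev,
                                dtype=torch.int32).view(B, pages_per_seq)
    seq_lens = torch.full((B, ), SEQ, device=dev, dtype=torch.int32)

    decode_attention_fwd(q, kv_cache, v_cache, o, block_tables, seq_lens,
                         attn_logits, SPLITS, (KV_LORA + PE)**-0.5, PAGE)
    # o now holds per-head attention outputs in the latent space: [B, H, 512]
# --- end sketch ---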
@ -83,6 +83,7 @@ def get_attn_backend(
    block_size: int,
    is_attention_free: bool,
    is_blocksparse: bool = False,
    use_mla: bool = False,
) -> Type[AttentionBackend]:
    """Selects which attention backend to use and lazily imports it."""
    # Accessing envs.* behind an @lru_cache decorator can cause the wrong
@ -97,6 +98,7 @@ def get_attn_backend(
        is_attention_free=is_attention_free,
        is_blocksparse=is_blocksparse,
        use_v1=envs.VLLM_USE_V1,
        use_mla=use_mla,
    )


@ -109,6 +111,7 @@ def _cached_get_attn_backend(
    is_attention_free: bool,
    is_blocksparse: bool = False,
    use_v1: bool = False,
    use_mla: bool = False,
) -> Type[AttentionBackend]:
    if is_blocksparse:
        logger.info("Using BlocksparseFlashAttention backend.")
@ -141,7 +144,8 @@ def _cached_get_attn_backend(

    # get device-specific attn_backend
    attention_cls = current_platform.get_attn_backend_cls(
        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1)
        selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1,
        use_mla)
    if not attention_cls:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}")

@ -736,17 +736,24 @@ class ModelConfig:
    def get_hidden_size(self) -> int:
        return self.hf_text_config.hidden_size

    @property
    def is_deepseek_mla(self) -> bool:
        return hasattr(self.hf_text_config,
                       "model_type") and (self.hf_text_config.model_type
                                          in ('deepseek_v2', 'deepseek_v3'))

    def get_head_size(self) -> int:
        # TODO remove hard code
        if hasattr(self.hf_text_config,
                   "model_type") and (self.hf_text_config.model_type
                                      in ('deepseek_v2', 'deepseek_v3')):
            qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
                                       0)
            qk_nope_head_dim = getattr(self.hf_text_config, "qk_nope_head_dim",
                                       0)
            if qk_rope_head_dim and qk_nope_head_dim:
                return qk_rope_head_dim + qk_nope_head_dim
        if self.is_deepseek_mla:
            if self.use_mla:
                return self.hf_text_config.kv_lora_rank
            else:
                qk_rope_head_dim = getattr(self.hf_text_config,
                                           "qk_rope_head_dim", 0)
                qk_nope_head_dim = getattr(self.hf_text_config,
                                           "qk_nope_head_dim", 0)
                if qk_rope_head_dim and qk_nope_head_dim:
                    return qk_rope_head_dim + qk_nope_head_dim

        if self.is_attention_free:
            return 0
@ -805,6 +812,10 @@ class ModelConfig:

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        if self.use_mla:
            # With MLA, decode-time attention reduces to MQA over a single
            # shared latent KV head.
            return 1

        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
@ -956,6 +967,10 @@ class ModelConfig:
        return ModelRegistry.is_cross_encoder_model(architectures)

    @property
    def use_mla(self) -> bool:
        use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
        return use_mla

    def supported_runner_types(self) -> Set[RunnerType]:
        return {_TASK_RUNNER[task] for task in self.supported_tasks}

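Taken together, these config changes mean that when a DeepSeek-V2/V3 checkpoint is loaded and VLLM_MLA_DISABLE is unset, the KV cache is shaped for one latent head of width kv_lora_rank rather than num_kv_heads heads of width qk_nope_head_dim + qk_rope_head_dim. A standalone sketch of that decision in plain Python (not vLLM's ModelConfig; names and the DeepSeek-V2-like numbers are assumptions for illustration):

# --- illustrative sketch (not vLLM's ModelConfig) ---
def mla_cache_geometry(model_type: str, kv_lora_rank: int,
                       qk_nope_head_dim: int, qk_rope_head_dim: int,
                       num_kv_heads: int, mla_disabled: bool):
    """Returns (head_size, num_kv_heads) as the cache would see them."""
    is_deepseek_mla = model_type in ("deepseek_v2", "deepseek_v3")
    use_mla = is_deepseek_mla and not mla_disabled
    if use_mla:
        return kv_lora_rank, 1                       # one shared latent head
    if is_deepseek_mla:
        return qk_nope_head_dim + qk_rope_head_dim, num_kv_heads
    raise ValueError("not a DeepSeek MLA model")

assert mla_cache_geometry("deepseek_v2", 512, 128, 64, 128, False) == (512, 1)
assert mla_cache_geometry("deepseek_v2", 512, 128, 64, 128, True) == (192, 128)
# --- end sketch ---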
@ -931,7 +931,6 @@ class EngineArgs:
            type=str,
            default="auto",
            help='The worker class to use for distributed execution.')

        parser.add_argument(
            "--generation-config",
            type=nullable_str,

14 vllm/envs.py
@ -77,6 +77,8 @@ if TYPE_CHECKING:
    V_SCALE_CONSTANT: int = 100
    VLLM_SERVER_DEV_MODE: bool = False
    VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
    VLLM_MLA_DISABLE: bool = False
    VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True


def get_default_cache_root():
@ -506,6 +508,18 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    # TTFT and overall throughput.
    "VLLM_V1_OUTPUT_PROC_CHUNK_SIZE":
    lambda: int(os.getenv("VLLM_V1_OUTPUT_PROC_CHUNK_SIZE", "128")),

    # If set, vLLM will disable the MLA attention optimizations.
    "VLLM_MLA_DISABLE":
    lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),

    # Controls whether matrix absorption is performed for MLA decode, i.e.
    # absorbing W_UK into W_Q/W_UK and W_UV into W_O. Absorbing the matrices
    # reduces the runtime FLOPs needed to compute MLA, but requires storing
    # the larger weights W_Q_UK and W_UV_O, so it can increase memory usage.
    # This is enabled by default.
    "VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
    lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
}

# end-env-vars-definition

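Matrix absorption works because the up-projections sit immediately next to other linear maps inside the attention computation, so they can be folded in ahead of time by matrix associativity: the score between a query and a cached latent can be computed either by up-projecting the latent into a key or by pre-multiplying the query side with W_UK, and W_UV can likewise be folded into W_O. A small numerical check of those identities (a sketch with made-up dimensions and weight names; vLLM's actual absorbed weights are the W_Q_UK and W_UV_O mentioned in the comment above):

# --- illustrative sketch (dimensions are made up) ---
import torch

torch.manual_seed(0)
H, D_NOPE, KV_LORA, D_V, D_MODEL = 8, 128, 512, 128, 1024

q_nope = torch.randn(H, D_NOPE)            # per-head query, "nope" part
c = torch.randn(KV_LORA)                   # cached latent vector for one token
W_UK = torch.randn(H, D_NOPE, KV_LORA)     # per-head key up-projection
W_UV = torch.randn(H, KV_LORA, D_V)        # per-head value up-projection
W_O = torch.randn(H * D_V, D_MODEL)        # output projection

# Un-absorbed: up-project the latent into per-head keys, then dot with q.
scores_ref = torch.einsum('hd,hdk,k->h', q_nope, W_UK, c)
# Absorbed: fold W_UK into the query once, then score directly in latent space.
q_latent = torch.einsum('hd,hdk->hk', q_nope, W_UK)
scores_abs = q_latent @ c
assert torch.allclose(scores_ref, scores_abs, rtol=1e-3, atol=1e-2)

# The same trick folds W_UV into W_O on the output side:
attn_latent = torch.randn(H, KV_LORA)      # per-head attention output, latent space
out_ref = torch.einsum('hk,hkv->hv', attn_latent, W_UV).reshape(-1) @ W_O
W_UV_O = torch.einsum('hkv,hvm->hkm', W_UV, W_O.view(H, D_V, D_MODEL))
out_abs = torch.einsum('hk,hkm->m', attn_latent, W_UV_O)
assert torch.allclose(out_ref, out_abs, rtol=1e-3, atol=1e-2)
# --- end sketch ---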
@ -23,6 +23,7 @@ from torch import nn
from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.attention import Attention
from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
                         VllmConfig, set_current_vllm_config)
from vllm.distributed import (get_tensor_model_parallel_rank,
@ -397,6 +398,11 @@ class DefaultModelLoader(BaseModelLoader):
                    # parameters onto device for processing and back off after.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
                elif isinstance(module, Attention) and \
                    hasattr(module, "process_weights_after_loading"):
                    # Attention modules may need to process weights after
                    # loading; currently this is only used by MLA.
                    module.process_weights_after_loading()
        return model.eval()


@ -28,7 +28,7 @@ from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
@ -326,12 +326,156 @@ class DeepseekV2Attention(nn.Module):
        return output


class DeepseekV2MLAAttention(nn.Module):
    """
    Main reference: DeepseekV2 paper, and FlashInfer Implementation
    (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).

    For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % tp_size == 0
        self.num_local_heads = num_heads // tp_size

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
                                        self.hidden_size,
                                        bias=False,
                                        quant_config=quant_config,
                                        prefix=f"{prefix}.o_proj")

        rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=self.rotary_emb,
            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
        )

        self.prefix = prefix
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
    ) -> torch.Tensor:
        if self.q_lora_rank is not None:
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
        else:
            hidden_states_or_q_c = hidden_states
        kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
            [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
        return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
                             attn_metadata)


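The forward pass above is the compression side of MLA: a single replicated projection produces the latent kv_c (width kv_lora_rank) plus the rotary key part k_pe (width qk_rope_head_dim), and only those two pieces are handed to the attention layer for caching. A shape-only sketch with DeepSeek-V2-like sizes (the numbers and the stand-in weight are assumptions for illustration, not the model code):

# --- illustrative sketch (shapes only) ---
import torch

T, HIDDEN, KV_LORA_RANK, QK_ROPE_HEAD_DIM = 8, 5120, 512, 64

hidden_states = torch.randn(T, HIDDEN)
# Stand-in for kv_a_proj_with_mqa: hidden -> [kv_lora_rank + qk_rope_head_dim]
kv_a_weight = torch.randn(KV_LORA_RANK + QK_ROPE_HEAD_DIM, HIDDEN)
latent = hidden_states @ kv_a_weight.t()                 # [T, 576]
kv_c, k_pe = latent.split([KV_LORA_RANK, QK_ROPE_HEAD_DIM], dim=-1)

# Per token, the cache stores 512 + 64 = 576 values in total, independent of
# the number of attention heads -- this is the memory saving of MLA.
assert kv_c.shape == (T, KV_LORA_RANK) and k_pe.shape == (T, QK_ROPE_HEAD_DIM)
# --- end sketch ---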
class DeepseekV2DecoderLayer(nn.Module):

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
@ -344,7 +488,11 @@ class DeepseekV2DecoderLayer(nn.Module):
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        self.self_attn = DeepseekV2Attention(
        if model_config.use_mla:
            attn_cls = DeepseekV2MLAAttention
        else:
            attn_cls = DeepseekV2Attention
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
@ -421,6 +569,7 @@ class DeepseekV2Model(nn.Module):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

@ -440,6 +589,7 @@ class DeepseekV2Model(nn.Module):
            lambda prefix: DeepseekV2DecoderLayer(
                config,
                prefix,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            ),

@ -31,7 +31,8 @@ class CpuPlatform(Platform):
    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        if selected_backend != _Backend.TORCH_SDPA:
            logger.info("Cannot use %s backend on CPU.", selected_backend)
        logger.info("Using Torch SDPA backend.")

@ -157,10 +157,14 @@ class CudaPlatformBase(Platform):

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1) -> str:
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        if use_v1:
            logger.info("Using Flash Attention backend on V1 engine.")
            return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
        if use_mla:
            logger.info("Using Triton MLA backend.")
            return "vllm.attention.backends.triton_mla.TritonMLABackend"
        if selected_backend == _Backend.FLASHINFER:
            logger.info("Using FlashInfer backend.")
            return "vllm.attention.backends.flashinfer.FlashInferBackend"
@ -171,7 +175,8 @@ class CudaPlatformBase(Platform):
            pass
        elif selected_backend:
            raise ValueError(
                f"Invalid attention backend for {cls.device_name}")
                f"Invalid attention backend for {cls.device_name}, "
                f"with use_v1: {use_v1} use_mla: {use_mla}")

        target_backend = _Backend.FLASH_ATTN
        if not cls.has_device_capability(80):

@ -27,7 +27,8 @@ class HpuPlatform(Platform):
    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        logger.info("Using HPUAttention backend.")
        return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"

@ -30,6 +30,7 @@ class _Backend(enum.Enum):
    TORCH_SDPA = enum.auto()
    OPENVINO = enum.auto()
    FLASHINFER = enum.auto()
    TRITON_MLA = enum.auto()
    HPU_ATTN = enum.auto()
    PALLAS = enum.auto()
    IPEX = enum.auto()
@ -139,7 +140,8 @@ class Platform:
    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        """Get the attention backend class of a device."""
        return ""

@ -30,7 +30,8 @@ class OpenVinoPlatform(Platform):
    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        if selected_backend != _Backend.OPENVINO:
            logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
        logger.info("Using OpenVINO Attention backend.")

@ -75,7 +75,8 @@ class RocmPlatform(Platform):

    @classmethod
    def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                             kv_cache_dtype, block_size, use_v1) -> str:
                             kv_cache_dtype, block_size, use_v1,
                             use_mla) -> str:
        selected_backend = (_Backend.ROCM_FLASH if selected_backend
                            == _Backend.FLASH_ATTN else selected_backend)
        if selected_backend == _Backend.ROCM_FLASH:

@ -29,7 +29,8 @@ class TpuPlatform(Platform):
    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        if selected_backend != _Backend.PALLAS:
            logger.info("Cannot use %s backend on TPU.", selected_backend)
        logger.info("Using Pallas backend.")

@ -27,7 +27,8 @@ class XPUPlatform(Platform):
    @classmethod
    def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
                             dtype: torch.dtype, kv_cache_dtype: Optional[str],
                             block_size: int, use_v1: bool) -> str:
                             block_size: int, use_v1: bool,
                             use_mla: bool) -> str:
        if selected_backend != _Backend.IPEX:
            logger.info("Cannot use %s backend on XPU.", selected_backend)
        logger.info("Using IPEX attention backend.")

@ -56,7 +56,8 @@ class CacheEngine:
            model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
            model_config.is_attention_free)
            model_config.is_attention_free,
            use_mla=model_config.use_mla)

        # Initialize the cache.
        self.gpu_cache = self._allocate_kv_cache(

@ -1066,6 +1066,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        ) if needs_attn_backend else None
        if self.attn_backend:
            self.attn_state = self.attn_backend.get_state_cls()(
@ -1467,8 +1468,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                                      dtype=self.model_config.dtype,
                                      device=self.device)

        with self.attn_state.graph_capture(max_batch_size), graph_capture(
                self.device) as graph_capture_context:
        with self.attn_state.graph_capture(
                max_batch_size, input_positions), graph_capture(
                    self.device) as graph_capture_context:
            # NOTE: Capturing the largest batch size first may help reduce the
            # memory usage of CUDA graph.
            for virtual_engine in range(
@ -1973,7 +1975,8 @@ class CUDAGraphRunner(nn.Module):

        # Copy the input tensors to the input buffers.
        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
        self.input_buffers["positions"].copy_(positions, non_blocking=True)
        if positions is not None:
            self.input_buffers["positions"].copy_(positions, non_blocking=True)

        if self.backend_name != "NO_ATTENTION":
            self.input_buffers["slot_mapping"].copy_(