mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Performance] Remove input pads in cutlass_mla and optimize v_proj output handling (#25184)
Signed-off-by: Alexander Matveev <amatveev@redhat.com> Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
committed by
yewentao256
parent
25dd155e60
commit
dbb029cfe1
@ -942,6 +942,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
kv_b_proj: ColumnParallelLinear,
|
||||
q_pad_num_heads: Optional[int] = None,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported for MLA")
|
||||
@ -959,6 +960,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.kv_b_proj = kv_b_proj
|
||||
self.q_pad_num_heads = q_pad_num_heads
|
||||
|
||||
if use_flashinfer_prefill():
|
||||
logger.debug_once("Using FlashInfer prefill for MLA")
|
||||
@ -1134,7 +1136,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
True, #Indicates actual_seq_lens are on GPU or CPU.
|
||||
)
|
||||
|
||||
def _v_up_proj(self, x):
|
||||
def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
|
||||
# Convert from (B, N, L) to (N, B, L)
|
||||
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
@ -1146,12 +1148,23 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
transpose_bm=True)
|
||||
# Convert from (B, N, V) to (B, N * V)
|
||||
x = x.reshape(-1, self.num_heads * self.v_head_dim)
|
||||
# Copy result
|
||||
out.copy_(x)
|
||||
else:
|
||||
# Convert from (B, N * V) to (N, B, V)
|
||||
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
|
||||
|
||||
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
|
||||
x = torch.bmm(x, self.W_UV)
|
||||
torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
|
||||
|
||||
# Convert from (N, B, V) to (B, N * V)
|
||||
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
|
||||
return x
|
||||
out_new = out.transpose(0, 1).reshape(
|
||||
-1, self.num_heads * self.v_head_dim)
|
||||
|
||||
# Adjust output buffer shape back to the original (B, N * V)
|
||||
N, B, V = out.shape
|
||||
out.resize_((B, N * V))
|
||||
out.copy_(out_new) # Copy result
|
||||
|
||||
def process_weights_after_loading(self, act_dtype: torch.dtype):
|
||||
|
||||
@ -1559,6 +1572,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
# Convert from (B, N, P) to (N, B, P)
|
||||
decode_q_nope = decode_q_nope.transpose(0, 1)
|
||||
|
||||
# Pads the head_dim if necessary (for the underlying kernel)
|
||||
if self.q_pad_num_heads is not None:
|
||||
B, N, L = decode_q_pe.shape
|
||||
decode_pe_padded = decode_q_pe.new_empty(
|
||||
(B, self.q_pad_num_heads, L))
|
||||
decode_pe_padded.resize_((B, N, L))
|
||||
decode_pe_padded.copy_(decode_q_pe)
|
||||
decode_q_pe = decode_pe_padded
|
||||
|
||||
if is_rocm_aiter_fp8bmm_enabled():
|
||||
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
|
||||
decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
|
||||
@ -1567,8 +1589,19 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
group_size=128,
|
||||
transpose_bm=True)
|
||||
else:
|
||||
# Pads the head_dim if necessary (for the underlying kernel)
|
||||
N, B, P = decode_q_nope.shape
|
||||
_, _, L = self.W_UK_T.shape
|
||||
if self.q_pad_num_heads is not None:
|
||||
decode_ql_nope = decode_q_nope.new_empty(
|
||||
(self.q_pad_num_heads, B, L))
|
||||
decode_ql_nope.resize_((N, B, L))
|
||||
|
||||
else:
|
||||
decode_ql_nope = decode_q_nope.new_empty((N, B, L))
|
||||
|
||||
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
|
||||
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
|
||||
torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
|
||||
# Convert from (N, B, L) to (B, N, L)
|
||||
decode_ql_nope = decode_ql_nope.transpose(0, 1)
|
||||
|
||||
@ -1603,5 +1636,5 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
|
||||
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
|
||||
|
||||
# v_up projection
|
||||
output[:num_decode_tokens] = self._v_up_proj(attn_out)
|
||||
self._v_up_proj(attn_out, out=output[:num_decode_tokens])
|
||||
return output_padded
|
||||
|
@ -74,6 +74,8 @@ class SM100Workspace:
|
||||
|
||||
g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
|
||||
|
||||
MAX_HEADS = 128
|
||||
|
||||
|
||||
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
can_return_lse_for_decode: bool = True
|
||||
@ -92,10 +94,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
super().__init__(num_heads,
|
||||
head_size,
|
||||
scale,
|
||||
num_kv_heads,
|
||||
alibi_slopes,
|
||||
sliding_window,
|
||||
kv_cache_dtype,
|
||||
logits_soft_cap,
|
||||
attn_type,
|
||||
kv_sharing_target_layer_name,
|
||||
q_pad_num_heads=MAX_HEADS,
|
||||
**mla_args)
|
||||
|
||||
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
|
||||
if any(unsupported_features):
|
||||
@ -157,14 +167,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
|
||||
MAX_HEADS = 128
|
||||
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
|
||||
if H < MAX_HEADS:
|
||||
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
|
||||
q_nope_padded[:, :H] = q_nope
|
||||
q_nope = q_nope_padded
|
||||
|
||||
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
|
||||
q_pe_padded[:, :H] = q_pe
|
||||
q_pe = q_pe_padded
|
||||
|
||||
assert len(page_table.shape) == 2
|
||||
B_block_table, block_num = page_table.shape
|
||||
@ -206,9 +208,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
|
||||
)
|
||||
|
||||
if H < MAX_HEADS:
|
||||
# Extract the subsets of the outputs
|
||||
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
|
||||
out = out[:, :H]
|
||||
if self.need_to_return_lse_for_decode:
|
||||
lse = lse[:, :H].contiguous()
|
||||
|
||||
return out, lse
|
||||
|
||||
|
Reference in New Issue
Block a user