Mirror of https://github.com/vllm-project/vllm.git
[Performance] Remove input pads in cutlass_mla and optimize v_proj output handling (#25184)
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
Commit dbb029cfe1 (parent 25dd155e60), committed by yewentao256

@@ -942,6 +942,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             qk_head_dim: int,
             v_head_dim: int,
             kv_b_proj: ColumnParallelLinear,
+            q_pad_num_heads: Optional[int] = None,
     ) -> None:
         if kv_sharing_target_layer_name is not None:
             raise NotImplementedError("KV sharing is not supported for MLA")
@@ -959,6 +960,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         self.qk_head_dim = qk_head_dim
         self.v_head_dim = v_head_dim
         self.kv_b_proj = kv_b_proj
+        self.q_pad_num_heads = q_pad_num_heads

         if use_flashinfer_prefill():
             logger.debug_once("Using FlashInfer prefill for MLA")
@@ -1134,7 +1136,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             True, #Indicates actual_seq_lens are on GPU or CPU.
         )

-    def _v_up_proj(self, x):
+    def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
         if is_rocm_aiter_fp8bmm_enabled():
@@ -1146,12 +1148,23 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                 transpose_bm=True)
             # Convert from (B, N, V) to (B, N * V)
             x = x.reshape(-1, self.num_heads * self.v_head_dim)
+            # Copy result
+            out.copy_(x)
         else:
+            # Convert from (B, N * V) to (N, B, V)
+            out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
+
             # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
-            x = torch.bmm(x, self.W_UV)
+            torch.bmm(x, self.W_UV, out=out)  # Reuse "out" to make it "hot"
+
             # Convert from (N, B, V) to (B, N * V)
-            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
-        return x
+            out_new = out.transpose(0, 1).reshape(
+                -1, self.num_heads * self.v_head_dim)
+
+            # Adjust output buffer shape back to the original (B, N * V)
+            N, B, V = out.shape
+            out.resize_((B, N * V))
+            out.copy_(out_new)  # Copy result

     def process_weights_after_loading(self, act_dtype: torch.dtype):

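
Note on the _v_up_proj change: the method now takes the destination buffer and lets torch.bmm write into a transposed view of it, instead of building a temporary (N, B, V) tensor that is transposed, reshaped and copied afterwards. A minimal, self-contained sketch of that buffer-reuse idea, with hypothetical sizes standing in for the real batch, head count, kv_lora_rank and v_head_dim:

    import torch

    B, N, L, V = 4, 16, 512, 128      # hypothetical sizes
    x = torch.randn(B, N, L)          # attention output, (B, N, L)
    W_UV = torch.randn(N, L, V)       # per-head up-projection weights
    out = torch.empty(B, N * V)       # caller-provided output buffer

    # Old path: temporary bmm result, then transpose + reshape + copy.
    ref = torch.bmm(x.transpose(0, 1), W_UV).transpose(0, 1).reshape(B, N * V)

    # New path: view the buffer as (N, B, V) and let bmm write into it directly,
    # so the (B, N * V) buffer already holds the final layout.
    out_view = out.view(B, N, V).transpose(0, 1)   # strided view of `out`
    torch.bmm(x.transpose(0, 1), W_UV, out=out_view)

    assert torch.allclose(out, ref)

Because the (N, B, V) view shares storage with the caller's (B, N * V) buffer, the result lands in place with no extra allocation for the projected values.
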
@@ -1559,6 +1572,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
         # Convert from (B, N, P) to (N, B, P)
         decode_q_nope = decode_q_nope.transpose(0, 1)

+        # Pads the head_dim if necessary (for the underlying kernel)
+        if self.q_pad_num_heads is not None:
+            B, N, L = decode_q_pe.shape
+            decode_pe_padded = decode_q_pe.new_empty(
+                (B, self.q_pad_num_heads, L))
+            decode_pe_padded.resize_((B, N, L))
+            decode_pe_padded.copy_(decode_q_pe)
+            decode_q_pe = decode_pe_padded
+
         if is_rocm_aiter_fp8bmm_enabled():
             # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
             decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
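
Note on the padding above: it over-allocates with new_empty for q_pad_num_heads heads and then resize_-s the tensor back down to the real head count before copying. Shrinking with resize_ does not release the underlying storage, so the buffer keeps enough capacity for the kernel's padded head count while only the real heads are exposed. A small sketch of that behavior, with hypothetical sizes:

    import torch

    # Hypothetical: tokens, real heads, padded heads, rope dim.
    B, N, N_PAD, L = 8, 16, 128, 64
    q_pe = torch.randn(B, N, L)

    padded = q_pe.new_empty((B, N_PAD, L))        # capacity for N_PAD heads
    capacity = padded.untyped_storage().nbytes()
    padded.resize_((B, N, L))                     # shrink the logical shape only
    padded.copy_(q_pe)

    assert padded.shape == (B, N, L)
    assert padded.untyped_storage().nbytes() == capacity  # still sized for N_PAD
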
@@ -1567,8 +1589,19 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
                 group_size=128,
                 transpose_bm=True)
         else:
+            # Pads the head_dim if necessary (for the underlying kernel)
+            N, B, P = decode_q_nope.shape
+            _, _, L = self.W_UK_T.shape
+            if self.q_pad_num_heads is not None:
+                decode_ql_nope = decode_q_nope.new_empty(
+                    (self.q_pad_num_heads, B, L))
+                decode_ql_nope.resize_((N, B, L))
+
+            else:
+                decode_ql_nope = decode_q_nope.new_empty((N, B, L))
+
             # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
-            decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
+            torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
             # Convert from (N, B, L) to (B, N, L)
             decode_ql_nope = decode_ql_nope.transpose(0, 1)

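
Note on the decode nope path: it combines the same two ideas. The (N, B, L) result buffer is allocated once, over-sized to q_pad_num_heads when padding is requested, and torch.bmm fills it in place through out= instead of returning a fresh tensor. A hedged sketch with hypothetical shapes (the real W_UK_T and head counts come from the model config):

    import torch

    N, N_PAD, B, P, L = 16, 128, 8, 128, 512    # hypothetical shapes
    q_nope = torch.randn(N, B, P)               # per-head queries, (N, B, P)
    W_UK_T = torch.randn(N, P, L)               # per-head projection, (N, P, L)

    # Allocate with padded head capacity, then expose only the real heads.
    ql_nope = q_nope.new_empty((N_PAD, B, L))
    ql_nope.resize_((N, B, L))

    # (N, B, P) x (N, P, L) -> (N, B, L), written directly into the buffer.
    torch.bmm(q_nope, W_UK_T, out=ql_nope)

    assert torch.allclose(ql_nope, torch.bmm(q_nope, W_UK_T))
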
@@ -1603,5 +1636,5 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
             attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())

         # v_up projection
-        output[:num_decode_tokens] = self._v_up_proj(attn_out)
+        self._v_up_proj(attn_out, out=output[:num_decode_tokens])
         return output_padded
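
Note on the call site: the decode slice of output is now handed to _v_up_proj as the destination, so the projection writes straight into the final buffer rather than returning a temporary that the old slice assignment then copied. Roughly, with a hypothetical stand-in for the projection:

    import torch

    def v_up_proj_sketch(x: torch.Tensor, out: torch.Tensor) -> None:
        # Hypothetical stand-in: fills `out` in place.
        torch.mul(x, 2.0, out=out)

    output = torch.zeros(10, 4)
    attn_out = torch.randn(6, 4)

    # output[:6] is a view of output, so writing through out= lands directly in
    # the final buffer; no intermediate tensor, no slice-assignment copy.
    v_up_proj_sketch(attn_out, out=output[:6])
    assert torch.count_nonzero(output[6:]) == 0
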
@@ -74,6 +74,8 @@ class SM100Workspace:

 g_sm100_workspace = SM100Workspace(128 * 1024 * 1024)  # 128MB

+MAX_HEADS = 128
+

 class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
     can_return_lse_for_decode: bool = True
@@ -92,10 +94,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
             kv_sharing_target_layer_name: Optional[str],
             # MLA Specific Arguments
             **mla_args) -> None:
-        super().__init__(num_heads, head_size, scale, num_kv_heads,
-                         alibi_slopes, sliding_window, kv_cache_dtype,
-                         logits_soft_cap, attn_type,
-                         kv_sharing_target_layer_name, **mla_args)
+        super().__init__(num_heads,
+                         head_size,
+                         scale,
+                         num_kv_heads,
+                         alibi_slopes,
+                         sliding_window,
+                         kv_cache_dtype,
+                         logits_soft_cap,
+                         attn_type,
+                         kv_sharing_target_layer_name,
+                         q_pad_num_heads=MAX_HEADS,
+                         **mla_args)

         unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
         if any(unsupported_features):
@@ -157,14 +167,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):

         MAX_HEADS = 128
         assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
-        if H < MAX_HEADS:
-            q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
-            q_nope_padded[:, :H] = q_nope
-            q_nope = q_nope_padded
-
-            q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
-            q_pe_padded[:, :H] = q_pe
-            q_pe = q_pe_padded

         assert len(page_table.shape) == 2
         B_block_table, block_num = page_table.shape
@@ -206,9 +208,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
         )

         if H < MAX_HEADS:
+            # Extract the subsets of the outputs
+            lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
             out = out[:, :H]
-            if self.need_to_return_lse_for_decode:
-                lse = lse[:, :H].contiguous()

         return out, lse