[Performance] Remove input pads in cutlass_mla and optimize v_proj output handling (#25184)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
Alexander Matveev
2025-09-22 21:20:53 -04:00
committed by yewentao256
parent 25dd155e60
commit dbb029cfe1
2 changed files with 55 additions and 20 deletions

View File

@ -942,6 +942,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
qk_head_dim: int, qk_head_dim: int,
v_head_dim: int, v_head_dim: int,
kv_b_proj: ColumnParallelLinear, kv_b_proj: ColumnParallelLinear,
q_pad_num_heads: Optional[int] = None,
) -> None: ) -> None:
if kv_sharing_target_layer_name is not None: if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not supported for MLA") raise NotImplementedError("KV sharing is not supported for MLA")
@ -959,6 +960,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.qk_head_dim = qk_head_dim self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim self.v_head_dim = v_head_dim
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.q_pad_num_heads = q_pad_num_heads
if use_flashinfer_prefill(): if use_flashinfer_prefill():
logger.debug_once("Using FlashInfer prefill for MLA") logger.debug_once("Using FlashInfer prefill for MLA")
@ -1134,7 +1136,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
True, #Indicates actual_seq_lens are on GPU or CPU. True, #Indicates actual_seq_lens are on GPU or CPU.
) )
def _v_up_proj(self, x): def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
# Convert from (B, N, L) to (N, B, L) # Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
if is_rocm_aiter_fp8bmm_enabled(): if is_rocm_aiter_fp8bmm_enabled():
@ -1146,12 +1148,23 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
transpose_bm=True) transpose_bm=True)
# Convert from (B, N, V) to (B, N * V) # Convert from (B, N, V) to (B, N * V)
x = x.reshape(-1, self.num_heads * self.v_head_dim) x = x.reshape(-1, self.num_heads * self.v_head_dim)
# Copy result
out.copy_(x)
else: else:
# Convert from (B, N * V) to (N, B, V)
out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V) # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV) torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
# Convert from (N, B, V) to (B, N * V) # Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) out_new = out.transpose(0, 1).reshape(
return x -1, self.num_heads * self.v_head_dim)
# Adjust output buffer shape back to the original (B, N * V)
N, B, V = out.shape
out.resize_((B, N * V))
out.copy_(out_new) # Copy result
def process_weights_after_loading(self, act_dtype: torch.dtype): def process_weights_after_loading(self, act_dtype: torch.dtype):
@ -1559,6 +1572,15 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# Convert from (B, N, P) to (N, B, P) # Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1) decode_q_nope = decode_q_nope.transpose(0, 1)
# Pads the head_dim if necessary (for the underlying kernel)
if self.q_pad_num_heads is not None:
B, N, L = decode_q_pe.shape
decode_pe_padded = decode_q_pe.new_empty(
(B, self.q_pad_num_heads, L))
decode_pe_padded.resize_((B, N, L))
decode_pe_padded.copy_(decode_q_pe)
decode_q_pe = decode_pe_padded
if is_rocm_aiter_fp8bmm_enabled(): if is_rocm_aiter_fp8bmm_enabled():
# Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L) # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope, decode_ql_nope = aiter_triton_fp8_bmm(decode_q_nope,
@ -1567,8 +1589,19 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
group_size=128, group_size=128,
transpose_bm=True) transpose_bm=True)
else: else:
# Pads the head_dim if necessary (for the underlying kernel)
N, B, P = decode_q_nope.shape
_, _, L = self.W_UK_T.shape
if self.q_pad_num_heads is not None:
decode_ql_nope = decode_q_nope.new_empty(
(self.q_pad_num_heads, B, L))
decode_ql_nope.resize_((N, B, L))
else:
decode_ql_nope = decode_q_nope.new_empty((N, B, L))
# Multiply (N, B, P) x (N, P, L) -> (N, B, L) # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T) torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
# Convert from (N, B, L) to (B, N, L) # Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1) decode_ql_nope = decode_ql_nope.transpose(0, 1)
@ -1603,5 +1636,5 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group()) attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
# v_up projection # v_up projection
output[:num_decode_tokens] = self._v_up_proj(attn_out) self._v_up_proj(attn_out, out=output[:num_decode_tokens])
return output_padded return output_padded

View File

@ -74,6 +74,8 @@ class SM100Workspace:
g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB g_sm100_workspace = SM100Workspace(128 * 1024 * 1024) # 128MB
MAX_HEADS = 128
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]): class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
can_return_lse_for_decode: bool = True can_return_lse_for_decode: bool = True
@ -92,10 +94,18 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
kv_sharing_target_layer_name: Optional[str], kv_sharing_target_layer_name: Optional[str],
# MLA Specific Arguments # MLA Specific Arguments
**mla_args) -> None: **mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads, super().__init__(num_heads,
alibi_slopes, sliding_window, kv_cache_dtype, head_size,
logits_soft_cap, attn_type, scale,
kv_sharing_target_layer_name, **mla_args) num_kv_heads,
alibi_slopes,
sliding_window,
kv_cache_dtype,
logits_soft_cap,
attn_type,
kv_sharing_target_layer_name,
q_pad_num_heads=MAX_HEADS,
**mla_args)
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
if any(unsupported_features): if any(unsupported_features):
@ -157,14 +167,6 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
MAX_HEADS = 128 MAX_HEADS = 128
assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}" assert H <= MAX_HEADS, f"H must be <= {MAX_HEADS}, but got {H}"
if H < MAX_HEADS:
q_nope_padded = q_nope.new_empty((B_q, MAX_HEADS, D_q_nope))
q_nope_padded[:, :H] = q_nope
q_nope = q_nope_padded
q_pe_padded = q_pe.new_empty((B_q, MAX_HEADS, D_q_pe))
q_pe_padded[:, :H] = q_pe
q_pe = q_pe_padded
assert len(page_table.shape) == 2 assert len(page_table.shape) == 2
B_block_table, block_num = page_table.shape B_block_table, block_num = page_table.shape
@ -206,9 +208,9 @@ class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
) )
if H < MAX_HEADS: if H < MAX_HEADS:
# Extract the subsets of the outputs
lse = lse[:, :H] if self.need_to_return_lse_for_decode else lse
out = out[:, :H] out = out[:, :H]
if self.need_to_return_lse_for_decode:
lse = lse[:, :H].contiguous()
return out, lse return out, lse