Adapt mla_v1 to use the mla_preprocess kernel (#3397)

### What this PR does / why we need it?

This pull request integrates a new `mla_preprocess` kernel to provide an
optimized path for MLA (Multi-head Latent Attention) decode operations on
Ascend hardware, gated behind an environment flag. The changes add utility
functions for weight transformation, a method that prepares weights for the
fused kernel, and logic that routes decode-only batches to the new path.
Code review flagged a bug in the `transdata` utility, where the row and
column padding amounts were swapped and would produce incorrect tensor
shapes and kernel failures, as well as a maintainability concern in
`trans_rope_weight`, which modified its input in place; a pure-function
alternative was suggested.

### Does this PR introduce _any_ user-facing change?

No user-facing changes by default. Users can opt in to the `mla_preprocess`
kernel by setting the environment variable `VLLM_ASCEND_ENABLE_MLAPO`.
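
For example, a minimal sketch of opting in from Python (the flag can just as well be exported in the shell before launching vLLM; the model name below is only illustrative):

```python
import os

# Opt in to the fused MLA preprocess (MLAPO) decode path; it is off by default.
os.environ["VLLM_ASCEND_ENABLE_MLAPO"] = "1"

from vllm import LLM

# Any MLA-based model applies here; the name is illustrative.
llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite")
print(llm.generate("Hello")[0].outputs[0].text)
```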

### How was this patch tested?

Dedicated Ascend kernels are not covered by our CI yet, so no extra
automated tests were added. Future MLA-focused regression runs will
cover this path.


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: Chen Chen <0109chenchen@gmail.com>

@@ -16,11 +16,13 @@ from vllm.model_executor.layers.linear import (LinearBase,
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
split_decodes_and_prefills,
trans_rope_weight, transdata,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import get_graph_params
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
@@ -639,6 +641,87 @@ class AscendMLAImpl(MLAAttentionImpl):
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
if envs.VLLM_ASCEND_ENABLE_MLAPO:
self._process_weights_for_fused_mlapo(act_dtype)
def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
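        # Fuse the kv_a_proj_with_mqa and q_a_proj down-projection weights into a
        # single wd_qkv: reorder the RoPE rows, concatenate, then convert to the
        # blocked NZ layout the fused kernel expects.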
kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data), dim=-1)
wd_qkv = wd_qkv.t().contiguous()
wd_qkv = transdata(wd_qkv,
block_size=(16, 32)).unsqueeze(0).contiguous()
self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)
kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale
kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl,
self.qk_rope_head_dim)
kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.deq_scale_qkv = torch.cat(
(kv_a_proj_deq_scl, self.q_a_proj.deq_scale), dim=-1).contiguous()
kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias
kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(
self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias,
self.qk_rope_head_dim)
kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(
self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
self.quant_bias_qkv = torch.cat(
(kv_a_proj_qt_bias, self.q_a_proj.quant_bias),
dim=-1).contiguous()
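        # Reorder the RoPE portion of each head in the q_proj weight and convert it
        # to the blocked layout (wu_q); its deq_scale and quant_bias are reordered
        # the same way below.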
wu_q = self.q_proj.weight.data
wu_q = wu_q.t().reshape(self.num_heads,
self.qk_nope_head_dim + self.qk_rope_head_dim,
-1)
wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
wu_q = wu_q.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim),
-1)
wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
self.wu_q = torch_npu.npu_format_cast(wu_q, 29)
qb_deq_scl = self.q_proj.deq_scale.data
qb_deq_scl = qb_deq_scl.reshape(
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
self.qb_deq_scl = qb_deq_scl.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
qb_qt_bias = self.q_proj.quant_bias.data
qb_qt_bias = qb_qt_bias.reshape(
self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
self.qb_qt_bias = qb_qt_bias.reshape(
self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))
device = self.q_a_proj.weight.device
self.gamma0 = torch.ones(
[self.q_a_proj.weight.shape[-1]],
dtype=act_dtype,
device=device,
)
self.beta0 = torch.zeros(
[self.q_a_proj.weight.shape[-1]],
dtype=act_dtype,
device=device,
)
self.gamma1 = self.q_a_layernorm.weight.data
self.beta1 = self.q_a_layernorm.bias.data
self.gamma2 = self.kv_a_layernorm.weight.data
self.quant_scale0 = self.q_a_proj.input_scale.data
self.quant_offset0 = self.q_a_proj.input_offset.data
self.quant_scale1 = self.q_proj.input_scale.data
self.quant_offset1 = self.q_proj.input_offset.data
self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)
def _compute_prefill_context(
self,
q_nope: torch.Tensor,
@@ -961,6 +1044,68 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata.before_comm_event.wait()
return self._v_up_proj(attn_output)
def _mla_decode_preprocess(self, hidden_states, kv_cache, attn_metadata):
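        # Decode-only fast path: a single mla_preprocess kernel call replaces the
        # separate projection, normalization, RoPE and KV-cache-write steps of the
        # generic preprocess.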
bsz = attn_metadata.num_decode_tokens
hidden_states = hidden_states[:bsz]
cos_shape = attn_metadata.decode.cos.shape
cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1])
sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1])
decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1]
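        # Pre-allocate the kernel outputs: q_nope lives in the per-head
        # kv_lora_rank space (absorbed via W_UK_T), q_pe keeps the rope head dim.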
decode_q_nope = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0],
decode_k_nope.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
decode_q_pe = torch.empty(
(hidden_states.shape[0], self.W_UK_T.shape[0],
decode_k_pe.shape[-1]),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
torch.ops._C_ascend.mla_preprocess(
hidden_states,
self.gamma0,
self.beta0,
self.wd_qkv,
self.deq_scale_qkv,
self.gamma1,
self.beta1,
self.wu_q,
self.qb_deq_scl,
self.gamma2,
cos,
sin,
self.W_UK_T,
decode_k_nope,
decode_k_pe,
attn_metadata.slot_mapping[:bsz].flatten(),
quant_scale0=self.quant_scale0,
quant_offset0=self.quant_offset0,
bias0=self.quant_bias_qkv,
quant_scale1=self.quant_scale1,
quant_offset1=self.quant_offset1,
bias1=self.qb_qt_bias,
ctkv_scale=self.ctkv_scale,
q_nope_scale=self.q_nope_scale,
cache_mode="krope_ctkv",
quant_mode="per_tensor_quant_asymm",
q_out0=decode_q_nope,
kv_cache_out0=decode_k_nope,
q_out1=decode_q_pe,
kv_cache_out1=decode_k_pe,
)
decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
self.kv_lora_rank)
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
decode_preprocess_res = DecodeMLAPreprocessResult(
decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe)
return decode_preprocess_res, None
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
attn_metadata, need_gather_q_kv):
# MLA Preprocess:
@@ -1065,6 +1210,12 @@ class AscendMLAImpl(MLAAttentionImpl):
device=hidden_states.device)
# MLA Preprocess
forward_context = get_forward_context()
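        # With MLAPO enabled, batches that contain no prefill requests take the
        # fused decode preprocess; everything else falls back to the generic path.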
if (envs.VLLM_ASCEND_ENABLE_MLAPO and
(attn_metadata is None or not forward_context.with_prefill)):
decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
hidden_states, kv_cache, attn_metadata)
else:
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
layer_name, hidden_states, kv_cache, attn_metadata,
need_gather_q_kv)


@@ -3,6 +3,7 @@ from dataclasses import dataclass
from typing import Any, List
import torch
import torch.nn.functional as F
import torch_npu
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
@@ -153,3 +154,39 @@ def version_check():
if full_date >= "20250919":
return True
return False
def round_up(val: int, align: int) -> int:
    """Round ``val`` up to the nearest multiple of ``align`` (0 if align is 0)."""
    if align == 0:
        return 0
    return -(val // -align) * align
def trans_rope_weight(weight, rope_dim):
    """De-interleave the trailing ``rope_dim`` RoPE rows of ``weight`` (even
    indices first, then odd) into the layout expected by the fused kernel."""
    if rope_dim == 0:
        return weight.contiguous()
nope_part = weight[..., :-rope_dim, :]
rope_part = weight[..., -rope_dim:, :]
reordered_rope_part = torch.cat(
(rope_part[..., ::2, :], rope_part[..., 1::2, :]), dim=-2)
return torch.cat((nope_part, reordered_rope_part), dim=-2).contiguous()
def transdata(nd_mat, block_size: tuple = (16, 16)):
    """Pad a 2-D matrix up to multiples of ``block_size`` and rearrange it into
    the blocked (NZ-style) layout consumed by the fused kernel."""
    r = round_up(nd_mat.shape[0], block_size[0])
    c = round_up(nd_mat.shape[1], block_size[1])
    r_pad = r - nd_mat.shape[0]
    c_pad = c - nd_mat.shape[1]
    # F.pad pads the last dimension first: pad columns by c_pad, then rows by r_pad.
    nd_mat = F.pad(nd_mat, (0, c_pad, 0, r_pad))
nz_mat = torch.permute(
torch.reshape(
nd_mat,
(r // block_size[0], block_size[0], c // block_size[1],
block_size[1]),
),
[2, 0, 1, 3],
)
nz_mat = torch.reshape(
nz_mat,
(nz_mat.shape[0], nz_mat.shape[1] * nz_mat.shape[2], nz_mat.shape[3]))
return nz_mat
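
For reference, a small sketch of what these helpers do to concrete shapes (assuming they are exposed from `vllm_ascend.attention.utils`, as the import in the attention backend above suggests, and that rows and columns are padded in the intended order):

```python
import torch

from vllm_ascend.attention.utils import round_up, trans_rope_weight, transdata

# round_up pads a dimension to the next block multiple.
assert round_up(24, 16) == 32 and round_up(40, 32) == 64

# transdata: (rows, cols) -> (cols / blk_c, padded_rows, blk_c) blocked layout.
nz = transdata(torch.randn(24, 40), block_size=(16, 32))
assert nz.shape == (2, 32, 32)

# trans_rope_weight de-interleaves the trailing rope_dim rows: 4,5,6,7 -> 4,6,5,7.
w = torch.arange(8.0).view(8, 1)
assert trans_rope_weight(w, rope_dim=4)[-4:, 0].tolist() == [4.0, 6.0, 5.0, 7.0]
```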