[Feat] Qwen3 MoE supports npu_add_rms_norm_quant op by default, updates the op with bias, resolves conflict with weight prefetch (#3465)

### What this PR does / why we need it?
1. In quantization scenarios, Qwen3 MoE now uses the fused add_rms_norm_quant op instead of a separate add_rms_norm op followed by a quant op.
2. Fixes an accuracy issue with the torch_npu.npu_add_rms_norm_quant op when model weights are quantized with anti_method m4: m4 is an asymmetric outlier-suppression method that generates a non-zero norm bias, so the op call is updated to pass this bias into the calculation (see the sketch after this list).
3. Add a torch_npu version check.
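
For context, a minimal sketch of the fused path this enables is shown below. It mirrors the npu_add_rms_norm_quant call added in vllm_ascend/ops/layernorm.py; the wrapper function and argument names here are illustrative, and the `beta` argument requires a sufficiently new torch_npu build.

```python
from typing import Optional

import torch
import torch_npu  # the beta argument needs torch_npu >= 2.7.1.dev20250919


def fused_add_rms_norm_quant(x: torch.Tensor,
                             residual: torch.Tensor,
                             weight: torch.Tensor,
                             input_scale: torch.Tensor,
                             input_offset: torch.Tensor,
                             norm_bias: Optional[torch.Tensor],
                             eps: float):
    # One fused kernel replaces "npu_add_rms_norm + quant": residual add,
    # RMSNorm, optional norm bias (non-zero for anti_method m4 checkpoints),
    # and int8 quantization in a single op.
    x_quant, _, residual_out = torch_npu.npu_add_rms_norm_quant(
        x,
        residual,
        weight,
        input_scale,
        input_offset,
        beta=norm_bias,
        epsilon=eps)
    return x_quant, residual_out
```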

### Does this PR introduce _any_ user-facing change?
The new feature takes effect only when the torch_npu version is >= torch_npu-2.7.1.dev20250919; a quick way to check the installed build is sketched below.
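
An illustrative way to check whether the installed torch_npu build is new enough (this mirrors the version_check() helper added in vllm_ascend/utils.py; the function name here is made up for the example):

```python
import re

import torch_npu


def supports_arnq_bias() -> bool:
    """True when the torch_npu dev-build date is >= 20250919 (same rule as version_check())."""
    match = re.search(r"dev(\d{8})", torch_npu.version.__version__)
    return bool(match) and match.group(1) >= "20250919"


print(supports_arnq_bias())
```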

### How was this patch tested?
1. No special parameters or new environment variables need to be set; the new feature takes effect when the torch_npu version is >= torch_npu-2.7.1.dev20250919.
2. Tested with quantized Qwen3 MoE models such as Qwen3-235B-A22B-W8A8, Qwen3-30B-A3B-W8A8, and Qwen3-235B-A22B-Instruct-2507-m4 (anti_method m4); see the smoke-test sketch below.
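
A rough example of this kind of smoke test (a hedged sketch: the model path and the quantization="ascend" setting are assumptions about the local vllm-ascend setup, not something this patch introduces):

```python
# Offline generation against a W8A8 Qwen3 MoE checkpoint; no new flags or env
# vars are needed for the fused path, it activates from the torch_npu version
# plus the Ascend quant config alone.
from vllm import LLM, SamplingParams

llm = LLM(model="Qwen3-30B-A3B-W8A8",   # local quantized checkpoint (assumed path)
          quantization="ascend")        # vllm-ascend quant backend (assumed setting)
out = llm.generate(["Hello, world"], SamplingParams(max_tokens=16))
print(out[0].outputs[0].text)
```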

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: h30027576 <huangdong51@huawei.com>
huangdong2022
2025-10-17 09:30:51 +08:00
committed by GitHub
parent 4c4a8458a5
commit 3a53bbc508
9 changed files with 121 additions and 38 deletions

View File

@@ -7,6 +7,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from tests.ut.base import PytestBase
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
from vllm_ascend.utils import version_check
def mock_rms_norm(x, weight, eps):
@@ -26,6 +27,15 @@ def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
return x_out_quant, None, residual_out_quant
def mock_add_rms_norm_quant_with_bias(x, residual, weight, quant_scale,
quant_offset, beta, epsilon):
x_out = 2 * x
residual_out = 2 * residual
x_out_quant = x_out.to(torch.int8)
residual_out_quant = residual_out.to(torch.int8)
return x_out_quant, None, residual_out_quant
class TestAscendRMSNorm(PytestBase):
@pytest.fixture(autouse=True)
@@ -33,8 +43,10 @@ class TestAscendRMSNorm(PytestBase):
mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
mocker.patch("torch_npu.npu_add_rms_norm",
side_effect=mock_add_rms_norm)
torch_npu_check = version_check()
arnq_side_effect = mock_add_rms_norm_quant_with_bias if torch_npu_check else mock_add_rms_norm_quant
mocker.patch("torch_npu.npu_add_rms_norm_quant",
side_effect=mock_add_rms_norm_quant)
side_effect=arnq_side_effect)
mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done",
side_effect=lambda x: None)
@@ -70,8 +82,10 @@ class TestAscendRMSNorm(PytestBase):
mock_model_instance = mocker.MagicMock()
mock_forward_context.model_instance = mock_model_instance
torch_npu_check = version_check()
num_hidden_layers = 3 if torch_npu_check else 2
mock_model_instance.model.layers = [
mocker.MagicMock() for _ in range(2)
mocker.MagicMock() for _ in range(num_hidden_layers)
]
mock_layer_0 = mock_model_instance.model.layers[0]
@@ -101,7 +115,7 @@ class TestAscendRMSNorm(PytestBase):
mock_forward_context.addrmsnorm_quant_fusion_enabled = True
mock_forward_context.prefetch_mlp_enabled = False
mock_forward_context.layer_idx = 0
mock_forward_context.num_hidden_layers = 2
mock_forward_context.num_hidden_layers = num_hidden_layers
mock_forward_context.fusion_linear = "gate_up_dense"
# Ensure fusion and layer_idx increment are handled correctly
@@ -121,18 +135,37 @@ class TestAscendRMSNorm(PytestBase):
assert mock_forward_context.fusion_linear == "gate_up_dense"
assert mock_forward_context.layer_idx == 1
if torch_npu_check:
mock_forward_context.fusion_linear = "gate_moe"
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 3
assert mock_forward_context.fusion_linear == "qkv_dense"
fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense"
assert mock_forward_context.fusion_linear == fusion_linear_expected
assert mock_forward_context.layer_idx == 2
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 4
assert mock_forward_context.fusion_linear == "qkv_dense"
fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense"
assert mock_forward_context.fusion_linear == fusion_linear_expected
assert mock_forward_context.layer_idx == 2
if not torch_npu_check:
return
# last layer returned directly
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 5
assert mock_forward_context.fusion_linear == "qkv_moe"
assert mock_forward_context.layer_idx == 3
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 6
assert mock_forward_context.fusion_linear == "qkv_moe"
assert mock_forward_context.layer_idx == 3
if __name__ == '__main__':
unittest.main()

View File

@@ -11,7 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
set_forward_context)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp, is_moe_model
from vllm_ascend.utils import enable_sp, is_moe_model, version_check
if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
@@ -160,13 +160,18 @@ def set_ascend_forward_context(
# this optim now just support dense models due to the specific operators used.
# Once the necessary conditions are met, support for MOE models will also be added.
from vllm_ascend.quantization.quant_config import AscendQuantConfig
model_type_scope = ["llama", "qwen2", "qwen3"]
if version_check():
model_type_scope.append("qwen3_moe")
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
vllm_config.model_config.hf_config.model_type in model_type_scope and \
forward_context.layer_idx is not None
if addrmsnorm_quant_fusion_enabled:
forward_context.model_instance = model_instance
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
if vllm_config.model_config.hf_config.model_type == "qwen3_moe":
forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe"
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
if num_tokens is None and attn_metadata is not None:

View File

@@ -33,13 +33,12 @@ from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
version_check,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec)
nd_to_nz_2d, nd_to_nz_spec, version_check)
from ..utils import weak_ref_tensors

View File

@@ -1,10 +1,8 @@
import functools
from dataclasses import dataclass
from typing import Any, List
import torch
import torch.nn.functional as F
import torch_npu
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group,
is_v1_kv_transfer_group)
@@ -142,20 +140,6 @@ def maybe_save_kv_layer_to_connector(
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
@functools.cache
def version_check():
import re
torch_npu_version = torch_npu.version.__version__
date_pattern = r'dev(\d{8})'
match = re.search(date_pattern, torch_npu_version)
if match:
full_date = match.group(1)
if full_date >= "20250919":
return True
return False
def round_up(val: int, align: int) -> int:
if align == 0:
return 0

View File

@@ -18,7 +18,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.platforms import current_platform
from vllm_ascend.attention.utils import version_check
from vllm_ascend.utils import version_check
from ..utils import weak_ref_tensors

View File

@@ -18,21 +18,36 @@
from typing import Optional, Tuple, Union, cast
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm_ascend.utils import version_check
def _addrmsnorm_forward_oot(
self,
x: torch.Tensor,
residual: torch.Tensor,
layer: Optional[torch.nn.Module] = None,
bias: Optional[torch.nn.Parameter] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
from vllm_ascend.utils import is_310p
torch_npu_check = version_check()
if layer is not None and not is_310p():
if torch_npu_check:
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
self.weight,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
beta=bias,
epsilon=self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm_quant(
x,
residual,
@@ -50,12 +65,32 @@ def _addrmsnorm_forward_oot(
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
if torch_npu_check and bias is not None:
x.add_(bias)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
class AscendRMSNorm(RMSNorm):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
vllm_config = get_current_vllm_config()
self.bias = None
self.torch_npu_check = version_check()
# quantization with anti_method m4 will generate none-zero norm bias
if self.torch_npu_check and vllm_config.quant_config is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
def forward_oot(
self,
x: torch.Tensor,
@@ -66,10 +101,13 @@ class AscendRMSNorm(RMSNorm):
if residual is not None:
assert x.size(0) == residual.size(0)
x, residual = _addrmsnorm_forward_oot(
self, x, residual, self.next_need_quant_fusion_linear)
self, x, residual, self.next_need_quant_fusion_linear,
self.bias)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
if self.torch_npu_check and self.bias is not None:
x.add_(self.bias)
return x
@property
@@ -99,6 +137,13 @@ class AscendRMSNorm(RMSNorm):
# does not need to be repeated
if not forward_context.prefetch_mlp_enabled:
forward_context.layer_idx += 1
elif fusion_linear == "qkv_moe":
next_linear = model_instance.model.layers[
layer_idx].self_attn.qkv_proj
forward_context.fusion_linear = "gate_moe"
elif fusion_linear == "gate_moe":
forward_context.fusion_linear = "qkv_moe"
forward_context.layer_idx += 1
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
if next_linear is not None and \
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):

View File

@@ -177,7 +177,6 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
return hidden_states

View File

@@ -7,6 +7,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import WeightPrefetchConfig
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
AscendRowParallelLinear)
from vllm_ascend.utils import version_check
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
MOE_PREFETCH_TOKEN_THRESHOLD = 96
@@ -82,14 +83,15 @@ class WeightPrefetchMethod:
if not self.moe.is_active_this_forward:
return
forward_context = get_forward_context()
if not version_check():
forward_context.layer_idx += 1
weight = forward_context.model_instance.model.layers[
forward_context.layer_idx].mlp.experts.w13_weight
forward_context.layer_idx - 1].mlp.experts.w13_weight
weight_size = weight.data.element_size() * weight.data.numel(
) * self.moe.prefetch_ratio.get(prefix, 0)
torch.ops.vllm.prefetch_preprocess(weight=weight,
start_flag=None,
max_weight_size=int(weight_size))
forward_context.layer_idx += 1
def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor):
if not self.moe.is_active_this_forward:

View File

@@ -546,7 +546,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
if vllm_config is not None and \
vllm_config.quant_config is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()) and \
not version_check():
REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm
for name, op_cls in REGISTERED_ASCEND_OPS.items():
@@ -725,3 +726,18 @@ def calculate_dp_buffer_size() -> int:
def is_hierarchical_communication_enabled():
return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
@functools.cache
def version_check():
"""check if torch_npu version >= dev20250919"""
import re
torch_npu_version = torch_npu.version.__version__
date_pattern = r'dev(\d{8})'
match = re.search(date_pattern, torch_npu_version)
if match:
full_date = match.group(1)
if full_date >= "20250919":
return True
return False