mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
[Feat]Qwen3 Moe supports npu_add_rms_norm_quant op by default, update op with bias, resolve conflict with weight prefetch (#3465)
### What this PR does / why we need it? 1.qwen3 moe uses add_rms_norm_quant op instead of 'add_rms_norm op and quant op' during quantization scene. 2.torch_npu.add_rms_norm_quant op fixed accuracy while model weights is quantized by anti_method m4, m4 quantization is asymmetric outlier suppression method, it will generate none-zero norm bias, add_rms_norm_quant op updated to add this parameter to calculate. 3. add torch-npu check ### Does this PR introduce _any_ user-facing change? new feature works if torch_npu version >= torch_npu-2.7.1.dev20250919 ### How was this patch tested? 1.no special parameters to set, no new envs to set. new feature works if torch_npu version >= torch_npu-2.7.1.dev20250919 2.use qwen3 moe quantization model to test ,such as Qwen3-235B-A22B-W8A8, Qwen3-30B-A3B-W8A8, Qwen3-235B-A22B-Instruct-2507-m4 (anti_method m4) - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: h30027576 <huangdong51@huawei.com>
This commit is contained in:
@ -7,6 +7,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
|
||||
from tests.ut.base import PytestBase
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
from vllm_ascend.utils import version_check
|
||||
|
||||
|
||||
def mock_rms_norm(x, weight, eps):
|
||||
@ -26,6 +27,15 @@ def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
|
||||
return x_out_quant, None, residual_out_quant
|
||||
|
||||
|
||||
def mock_add_rms_norm_quant_with_bias(x, residual, weight, quant_scale,
|
||||
quant_offset, beta, epsilon):
|
||||
x_out = 2 * x
|
||||
residual_out = 2 * residual
|
||||
x_out_quant = x_out.to(torch.int8)
|
||||
residual_out_quant = residual_out.to(torch.int8)
|
||||
return x_out_quant, None, residual_out_quant
|
||||
|
||||
|
||||
class TestAscendRMSNorm(PytestBase):
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -33,8 +43,10 @@ class TestAscendRMSNorm(PytestBase):
|
||||
mocker.patch("torch_npu.npu_rms_norm", side_effect=mock_rms_norm)
|
||||
mocker.patch("torch_npu.npu_add_rms_norm",
|
||||
side_effect=mock_add_rms_norm)
|
||||
torch_npu_check = version_check()
|
||||
arnq_side_effect = mock_add_rms_norm_quant_with_bias if torch_npu_check else mock_add_rms_norm_quant
|
||||
mocker.patch("torch_npu.npu_add_rms_norm_quant",
|
||||
side_effect=mock_add_rms_norm_quant)
|
||||
side_effect=arnq_side_effect)
|
||||
mocker.patch("torch.ops.vllm.maybe_wait_prefetch_done",
|
||||
side_effect=lambda x: None)
|
||||
|
||||
@ -70,8 +82,10 @@ class TestAscendRMSNorm(PytestBase):
|
||||
|
||||
mock_model_instance = mocker.MagicMock()
|
||||
mock_forward_context.model_instance = mock_model_instance
|
||||
torch_npu_check = version_check()
|
||||
num_hidden_layers = 3 if torch_npu_check else 2
|
||||
mock_model_instance.model.layers = [
|
||||
mocker.MagicMock() for _ in range(2)
|
||||
mocker.MagicMock() for _ in range(num_hidden_layers)
|
||||
]
|
||||
|
||||
mock_layer_0 = mock_model_instance.model.layers[0]
|
||||
@ -101,7 +115,7 @@ class TestAscendRMSNorm(PytestBase):
|
||||
mock_forward_context.addrmsnorm_quant_fusion_enabled = True
|
||||
mock_forward_context.prefetch_mlp_enabled = False
|
||||
mock_forward_context.layer_idx = 0
|
||||
mock_forward_context.num_hidden_layers = 2
|
||||
mock_forward_context.num_hidden_layers = num_hidden_layers
|
||||
mock_forward_context.fusion_linear = "gate_up_dense"
|
||||
|
||||
# Ensure fusion and layer_idx increment are handled correctly
|
||||
@ -121,18 +135,37 @@ class TestAscendRMSNorm(PytestBase):
|
||||
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
||||
assert mock_forward_context.layer_idx == 1
|
||||
|
||||
if torch_npu_check:
|
||||
mock_forward_context.fusion_linear = "gate_moe"
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 3
|
||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||
fusion_linear_expected = "qkv_moe" if torch_npu_check else "qkv_dense"
|
||||
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
||||
assert mock_forward_context.layer_idx == 2
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 4
|
||||
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||
fusion_linear_expected = "gate_moe" if torch_npu_check else "qkv_dense"
|
||||
assert mock_forward_context.fusion_linear == fusion_linear_expected
|
||||
assert mock_forward_context.layer_idx == 2
|
||||
|
||||
if not torch_npu_check:
|
||||
return
|
||||
# last layer returned directly
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 5
|
||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
||||
assert mock_forward_context.layer_idx == 3
|
||||
|
||||
x_out, residual_out = layer.forward_oot(x, residual)
|
||||
|
||||
assert mock_get_forward_context.call_count == 6
|
||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
||||
assert mock_forward_context.layer_idx == 3
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
@ -11,7 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
|
||||
set_forward_context)
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.utils import enable_sp, is_moe_model
|
||||
from vllm_ascend.utils import enable_sp, is_moe_model, version_check
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
|
||||
@ -160,13 +160,18 @@ def set_ascend_forward_context(
|
||||
# this optim now just support dense models due to the specific operators used.
|
||||
# Once the necessary conditions are met, support for MOE models will also be added.
|
||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||
model_type_scope = ["llama", "qwen2", "qwen3"]
|
||||
if version_check():
|
||||
model_type_scope.append("qwen3_moe")
|
||||
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
|
||||
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
|
||||
vllm_config.model_config.hf_config.model_type in model_type_scope and \
|
||||
forward_context.layer_idx is not None
|
||||
if addrmsnorm_quant_fusion_enabled:
|
||||
forward_context.model_instance = model_instance
|
||||
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
||||
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
|
||||
if vllm_config.model_config.hf_config.model_type == "qwen3_moe":
|
||||
forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe"
|
||||
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
|
||||
|
||||
if num_tokens is None and attn_metadata is not None:
|
||||
|
@ -33,13 +33,12 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||
maybe_save_kv_layer_to_connector,
|
||||
version_check,
|
||||
wait_for_kv_layer_from_connector)
|
||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||
update_graph_params_workspaces)
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||
nd_to_nz_2d, nd_to_nz_spec)
|
||||
nd_to_nz_2d, nd_to_nz_spec, version_check)
|
||||
|
||||
from ..utils import weak_ref_tensors
|
||||
|
||||
|
@ -1,10 +1,8 @@
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch_npu
|
||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||
has_kv_transfer_group,
|
||||
is_v1_kv_transfer_group)
|
||||
@ -142,20 +140,6 @@ def maybe_save_kv_layer_to_connector(
|
||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
|
||||
|
||||
|
||||
@functools.cache
|
||||
def version_check():
|
||||
import re
|
||||
torch_npu_version = torch_npu.version.__version__
|
||||
date_pattern = r'dev(\d{8})'
|
||||
|
||||
match = re.search(date_pattern, torch_npu_version)
|
||||
if match:
|
||||
full_date = match.group(1)
|
||||
if full_date >= "20250919":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def round_up(val: int, align: int) -> int:
|
||||
if align == 0:
|
||||
return 0
|
||||
|
@ -18,7 +18,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_ascend.attention.utils import version_check
|
||||
from vllm_ascend.utils import version_check
|
||||
|
||||
from ..utils import weak_ref_tensors
|
||||
|
||||
|
@ -18,21 +18,36 @@
|
||||
from typing import Optional, Tuple, Union, cast
|
||||
|
||||
import torch
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
|
||||
from vllm_ascend.utils import version_check
|
||||
|
||||
|
||||
def _addrmsnorm_forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
layer: Optional[torch.nn.Module] = None,
|
||||
bias: Optional[torch.nn.Parameter] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.utils import is_310p
|
||||
|
||||
torch_npu_check = version_check()
|
||||
if layer is not None and not is_310p():
|
||||
if torch_npu_check:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
self.weight,
|
||||
layer.aclnn_input_scale,
|
||||
layer.aclnn_input_offset,
|
||||
beta=bias,
|
||||
epsilon=self.variance_epsilon)
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm_quant(
|
||||
x,
|
||||
residual,
|
||||
@ -50,12 +65,32 @@ def _addrmsnorm_forward_oot(
|
||||
else:
|
||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||
x, residual, self.weight, self.variance_epsilon)
|
||||
if torch_npu_check and bias is not None:
|
||||
x.add_(bias)
|
||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||
return x, residual
|
||||
|
||||
|
||||
class AscendRMSNorm(RMSNorm):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
var_hidden_size: Optional[int] = None,
|
||||
has_weight: bool = True,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> None:
|
||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.bias = None
|
||||
self.torch_npu_check = version_check()
|
||||
# quantization with anti_method m4 will generate none-zero norm bias
|
||||
if self.torch_npu_check and vllm_config.quant_config is not None and \
|
||||
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
|
||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@ -66,10 +101,13 @@ class AscendRMSNorm(RMSNorm):
|
||||
if residual is not None:
|
||||
assert x.size(0) == residual.size(0)
|
||||
x, residual = _addrmsnorm_forward_oot(
|
||||
self, x, residual, self.next_need_quant_fusion_linear)
|
||||
self, x, residual, self.next_need_quant_fusion_linear,
|
||||
self.bias)
|
||||
return x, residual
|
||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||
self.variance_epsilon)
|
||||
if self.torch_npu_check and self.bias is not None:
|
||||
x.add_(self.bias)
|
||||
return x
|
||||
|
||||
@property
|
||||
@ -99,6 +137,13 @@ class AscendRMSNorm(RMSNorm):
|
||||
# does not need to be repeated
|
||||
if not forward_context.prefetch_mlp_enabled:
|
||||
forward_context.layer_idx += 1
|
||||
elif fusion_linear == "qkv_moe":
|
||||
next_linear = model_instance.model.layers[
|
||||
layer_idx].self_attn.qkv_proj
|
||||
forward_context.fusion_linear = "gate_moe"
|
||||
elif fusion_linear == "gate_moe":
|
||||
forward_context.fusion_linear = "qkv_moe"
|
||||
forward_context.layer_idx += 1
|
||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||
if next_linear is not None and \
|
||||
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
|
||||
|
@ -177,7 +177,6 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
|
@ -7,6 +7,7 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
from vllm_ascend.utils import version_check
|
||||
|
||||
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
||||
MOE_PREFETCH_TOKEN_THRESHOLD = 96
|
||||
@ -82,14 +83,15 @@ class WeightPrefetchMethod:
|
||||
if not self.moe.is_active_this_forward:
|
||||
return
|
||||
forward_context = get_forward_context()
|
||||
if not version_check():
|
||||
forward_context.layer_idx += 1
|
||||
weight = forward_context.model_instance.model.layers[
|
||||
forward_context.layer_idx].mlp.experts.w13_weight
|
||||
forward_context.layer_idx - 1].mlp.experts.w13_weight
|
||||
weight_size = weight.data.element_size() * weight.data.numel(
|
||||
) * self.moe.prefetch_ratio.get(prefix, 0)
|
||||
torch.ops.vllm.prefetch_preprocess(weight=weight,
|
||||
start_flag=None,
|
||||
max_weight_size=int(weight_size))
|
||||
forward_context.layer_idx += 1
|
||||
|
||||
def maybe_prefetch_moe_weight_postprocess(self, stop_flag: torch.Tensor):
|
||||
if not self.moe.is_active_this_forward:
|
||||
|
@ -546,7 +546,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
||||
|
||||
if vllm_config is not None and \
|
||||
vllm_config.quant_config is not None and \
|
||||
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
|
||||
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()) and \
|
||||
not version_check():
|
||||
REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm
|
||||
|
||||
for name, op_cls in REGISTERED_ASCEND_OPS.items():
|
||||
@ -725,3 +726,18 @@ def calculate_dp_buffer_size() -> int:
|
||||
def is_hierarchical_communication_enabled():
|
||||
return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
|
||||
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
|
||||
|
||||
|
||||
@functools.cache
|
||||
def version_check():
|
||||
"""check if torch_npu version >= dev20250919"""
|
||||
import re
|
||||
torch_npu_version = torch_npu.version.__version__
|
||||
date_pattern = r'dev(\d{8})'
|
||||
|
||||
match = re.search(date_pattern, torch_npu_version)
|
||||
if match:
|
||||
full_date = match.group(1)
|
||||
if full_date >= "20250919":
|
||||
return True
|
||||
return False
|
||||
|
Reference in New Issue
Block a user