diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml index c50894d37..1254f3a2f 100644 --- a/.github/workflows/_e2e_test.yaml +++ b/.github/workflows/_e2e_test.yaml @@ -173,7 +173,6 @@ jobs: if: ${{ inputs.type == 'full' }} run: | pytest -sv tests/e2e/multicard/test_data_parallel.py - pytest -sv tests/e2e/multicard/test_full_graph_mode.py pytest -sv tests/e2e/multicard/test_expert_parallel.py # external_launcher test is not stable enough. Fix it later # pytest -sv tests/e2e/multicard/test_external_launcher.py diff --git a/README.md b/README.md index 811b0ce08..9c255b1f1 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,7 @@ By using vLLM Ascend plugin, popular open-source models, including Transformer-l - Software: * Python >= 3.9, < 3.12 * CANN >= 8.2.rc1 (Ascend HDK version refers to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html)) - * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250919 + * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724 * vLLM (the same version as vllm-ascend) ## Getting Started diff --git a/README.zh.md b/README.zh.md index cd312fbf9..bb7ddb93f 100644 --- a/README.zh.md +++ b/README.zh.md @@ -44,7 +44,7 @@ vLLM 昇腾插件 (`vllm-ascend`) 是一个由社区维护的让vLLM在Ascend NP - 软件: * Python >= 3.9, < 3.12 * CANN >= 8.2.rc1 (Ascend HDK 版本参考[这里](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html)) - * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250919 + * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724 * vLLM (与vllm-ascend版本一致) ## 开始使用 diff --git a/docs/source/installation.md b/docs/source/installation.md index 40716c86d..0d3b54da2 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -13,7 +13,7 @@ This document describes how to install vllm-ascend manually. |---------------|----------------------------------|-------------------------------------------| | Ascend HDK | Refer to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html) | Required for CANN | | CANN | >= 8.2.RC1 | Required for vllm-ascend and torch-npu | - | torch-npu | >= 2.7.1.dev20250919 | Required for vllm-ascend, No need to install manually, it will be auto installed in below steps | + | torch-npu | >= 2.7.1.dev20250724 | Required for vllm-ascend, No need to install manually, it will be auto installed in below steps | | torch | >= 2.7.1 | Required for torch-npu and vllm | You have 2 way to install: diff --git a/pyproject.toml b/pyproject.toml index 479fbac15..1a140ce87 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ requires = [ "scipy", "setuptools>=64", "setuptools-scm>=8", - "torch-npu==2.7.1.dev20250919", + "torch-npu==2.7.1.dev20250724", "torch>=2.7.1", "torchvision", "wheel", diff --git a/requirements.txt b/requirements.txt index ef5b05c6f..7808e8525 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,4 @@ numba # Install torch_npu --pre --extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi -torch-npu==2.7.1.dev20250919 +torch-npu==2.7.1.dev20250724 diff --git a/tests/e2e/multicard/test_full_graph_mode.py b/tests/e2e/multicard/test_full_graph_mode.py deleted file mode 100644 index 6105ef70b..000000000 --- a/tests/e2e/multicard/test_full_graph_mode.py +++ /dev/null @@ -1,103 +0,0 @@ -# -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py -# -"""Compare the short outputs of HF and vLLM when using greedy sampling. - -Run `pytest tests/e2e/multicard/test_qwen3_moe.py`. -""" - -import os - -from vllm import SamplingParams - -from tests.e2e.conftest import VllmRunner -from tests.e2e.model_utils import check_outputs_equal - - -def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH(): - if 'HCCL_OP_EXPANSION_MODE' in os.environ: - del os.environ['HCCL_OP_EXPANSION_MODE'] - prompts = [ - ('Solve the following math problem step by step.' - 'The last line of your response should be of the form Answer: ' - '$Answer (without quotes) where $Answer is the answer to the problem.\n\n' - 'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$' - 'be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$,' - '$\\angle BDC = 90^\\circ$. Suppose $AD = 1$ and $\\frac{BD}{CD} = \\frac{3}{2}$.' - 'If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$,' - 'where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.' - ), - ('Solve the following math problem step by step.' - 'The last line of your response should be of the form Answer: ' - '$Answer (without quotes) where $Answer is the answer to the problem.\n\n' - 'Let $ABCD$ be a unit square in the plane. Points $X$ and $Y$ are chosen' - 'independently and uniformly at random on the perimeter of $ABCD$.' - 'If the expected value of the area of triangle $\\triangle AXY$' - 'can be expressed as $\\frac{m}{n}$, for relatively prime positive' - 'integers $m$ and $n$, compute $m+n$.'), - ('Solve the following math problem step by step.' - 'The last line of your response should be of the form Answer: ' - '$Answer (without quotes) where $Answer is the answer to the problem.\n\n' - 'Let $a, b, c$ be distinct numbers such that the equations $x^2 + ax + 1 = 0$' - 'and $x^2 + bx + c = 0$ have a common real root, and the equations $x^2 + x + a = 0$' - 'and $x^2 + cx + b = 0$ also have a common real root.' - 'Compute the sum $a + b + c$.') - ] - model = "Qwen/Qwen3-30B-A3B" - sampling_params = SamplingParams(max_tokens=5, - n=1, - temperature=0.0, - top_p=1.0, - top_k=1) - with VllmRunner(model, - max_model_len=1024, - tensor_parallel_size=2, - enforce_eager=False, - gpu_memory_utilization=0.95, - compilation_config={ - "cudagraph_capture_sizes": - [4, 8, 12, 16, 24, 32, 40, 48], - "cudagraph_mode": "FULL_DECODE_ONLY" - }) as runner: - vllm_fullgraph_outputs = runner.model.generate(prompts, - sampling_params) - with VllmRunner( - model, - max_model_len=1024, - tensor_parallel_size=2, - enforce_eager=True, - gpu_memory_utilization=0.95, - ) as runner: - vllm_eager_outputs = runner.model.generate(prompts, sampling_params) - - vllm_fullgraph_outputs_list = [] - for output in vllm_fullgraph_outputs: - vllm_fullgraph_outputs_list.append( - (output.outputs[0].index, output.outputs[0].text)) - - vllm_eager_outputs_list = [] - for output in vllm_eager_outputs: - vllm_eager_outputs_list.append( - (output.outputs[0].index, output.outputs[0].text)) - - check_outputs_equal( - outputs_0_lst=vllm_eager_outputs_list, - outputs_1_lst=vllm_fullgraph_outputs_list, - name_0="vllm_eager_outputs", - name_1="vllm_fullgraph_outputs", - ) diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py index 678b0bbe2..d553637d9 100644 --- a/tests/ut/attention/test_attention_v1.py +++ b/tests/ut/attention/test_attention_v1.py @@ -405,109 +405,6 @@ class TestAscendAttentionBackendImpl(TestBase): mock_paged_attention.assert_called_once() assert output.shape == (10, 8 * 64) - @patch('vllm_ascend.attention.attention_v1.get_forward_context') - @patch('vllm_ascend.attention.attention_v1.get_graph_params') - @patch('torch_npu._npu_reshape_and_cache') - @patch('torch_npu._npu_paged_attention') - @patch('torch.npu.graph_task_group_end') - @patch('torch.npu.graph_task_group_begin') - @patch('torch.npu.ExternalEvent') - @patch('torch_npu.npu.current_stream') - def test_paged_attention_with_existing_workspace( - self, - mock_get_forward_context, - mock_get_graph_params, - mock_npu_reshape_and_cache, - mock_paged_attention, - mock_graph_begin, - mock_graph_end, - mock_external_event_class, - mock_current_stream, - ): - graph_params = MagicMock() - attn_metadata = MagicMock() - num_tokens = 10 - - graph_params.workspaces = {num_tokens: 10} - graph_params.events = {num_tokens: []} - graph_params.attn_params = {num_tokens: []} - graph_params.handles = {num_tokens: []} - - query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size] - key_cache = MagicMock() - value_cache = MagicMock() - num_kv_heads = 4 - num_heads = 8 - scale = 0.1 - output = torch.randn(2, 5, 8) - - self_obj = MagicMock() - self_obj.key_cache = key_cache - self_obj.value_cache = value_cache - self_obj.num_kv_heads = num_kv_heads - self_obj.num_heads = num_heads - self_obj.scale = scale - - mock_stream = MagicMock() - mock_current_stream.return_value = mock_stream - mock_event_instance = MagicMock() - mock_external_event_class.return_value = mock_event_instance - - mock_handle = MagicMock() - mock_graph_end.return_value = mock_handle - - workspace = graph_params.workspaces.get(num_tokens) - self.assertEqual(workspace, 10) - - # 2. Handle graph capturing mode - stream = mock_current_stream() - event = mock_external_event_class() - event.wait(stream) - event.reset(stream) - graph_params.events[num_tokens].append(event) - graph_params.attn_params[num_tokens].append(( - query, - self_obj.key_cache, - self_obj.value_cache, - self_obj.num_kv_heads, - self_obj.num_heads, - self_obj.scale, - attn_metadata.block_tables, - attn_metadata.seq_lens, - output, - )) - - mock_event_instance.wait.assert_called_once_with(mock_stream) - mock_event_instance.reset.assert_called_once_with(mock_stream) - self.assertEqual(len(graph_params.events[num_tokens]), 1) - self.assertEqual(len(graph_params.attn_params[num_tokens]), 1) - - query = torch.randn(10, 8 * 64) - key = torch.randn(10, 8 * 64) - value = torch.randn(10, 8 * 64) - kv_cache = torch.empty(2, 5, 128, 8, 64) - metadata = self.attn_metadata - metadata.attn_state = AscendAttentionState.DecodeOnly - metadata.seq_lens = torch.tensor([10]) - metadata.block_tables = torch.zeros(1, 5, dtype=torch.long) - metadata.num_actual_tokens = 10 - metadata.slot_mapping = torch.zeros(10, dtype=torch.long) - layer = self.layer_no_quant - - mock_get_forward_context.return_value = MagicMock(capturing=True) - mock_get_graph_params.return_value = graph_params - - output = self.impl.forward(layer, - query, - key, - value, - kv_cache, - metadata, - trace_flag=False) - - mock_paged_attention.assert_called_once() - self.assertEqual(len(graph_params.handles[num_tokens]), 0) - @patch('torch_npu._npu_reshape_and_cache') @patch('torch_npu.npu_fused_infer_attention_score') def test_forward_decode_only_swa(self, mock_fused_infer_attention_score, diff --git a/tests/ut/ops/test_layernorm.py b/tests/ut/ops/test_layernorm.py index dd99088ce..b0c05a203 100644 --- a/tests/ut/ops/test_layernorm.py +++ b/tests/ut/ops/test_layernorm.py @@ -24,7 +24,7 @@ def mock_add_rms_norm(x, residual, weight, eps): def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset, - beta, epsilon): + epsilon): x_out = 2 * x residual_out = 2 * residual x_out_quant = x_out.to(torch.int8) @@ -94,7 +94,7 @@ class TestAscendRMSNorm(PytestBase): mock_model_instance = mocker.MagicMock() mock_forward_context.model_instance = mock_model_instance mock_model_instance.model.layers = [ - mocker.MagicMock() for _ in range(3) + mocker.MagicMock() for _ in range(2) ] mock_layer_0 = mock_model_instance.model.layers[0] @@ -124,7 +124,7 @@ class TestAscendRMSNorm(PytestBase): mock_forward_context.addrmsnorm_quant_fusion_enabled = True mock_forward_context.prefetch_mlp_enabled = False mock_forward_context.layer_idx = 0 - mock_forward_context.num_hidden_layers = 3 + mock_forward_context.num_hidden_layers = 2 mock_forward_context.fusion_linear = "gate_up_dense" # Ensure fusion and layer_idx increment are handled correctly @@ -144,32 +144,18 @@ class TestAscendRMSNorm(PytestBase): assert mock_forward_context.fusion_linear == "gate_up_dense" assert mock_forward_context.layer_idx == 1 - mock_forward_context.fusion_linear = "gate_moe" x_out, residual_out = layer.forward_oot(x, residual) assert mock_get_forward_context.call_count == 3 - assert mock_forward_context.fusion_linear == "qkv_moe" + assert mock_forward_context.fusion_linear == "qkv_dense" assert mock_forward_context.layer_idx == 2 x_out, residual_out = layer.forward_oot(x, residual) assert mock_get_forward_context.call_count == 4 - assert mock_forward_context.fusion_linear == "gate_moe" + assert mock_forward_context.fusion_linear == "qkv_dense" assert mock_forward_context.layer_idx == 2 - # last layer returned directly - x_out, residual_out = layer.forward_oot(x, residual) - - assert mock_get_forward_context.call_count == 5 - assert mock_forward_context.fusion_linear == "qkv_moe" - assert mock_forward_context.layer_idx == 3 - - x_out, residual_out = layer.forward_oot(x, residual) - - assert mock_get_forward_context.call_count == 6 - assert mock_forward_context.fusion_linear == "qkv_moe" - assert mock_forward_context.layer_idx == 3 - if __name__ == '__main__': unittest.main() diff --git a/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py b/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py index 3a98cfc5c..520155d2e 100644 --- a/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py +++ b/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py @@ -1,6 +1,5 @@ from unittest.mock import MagicMock, patch -import pytest import torch from tests.ut.base import TestBase @@ -17,10 +16,6 @@ class TestAscendW8A8FusedMoEMethod(TestBase): self.hidden_size, dtype=torch.bfloat16) - @pytest.mark.skipif( - True, - reason="fix me", - ) @patch("torch.distributed.all_to_all_single") @patch("torch_npu.npu_moe_re_routing") @patch("torch_npu.npu_grouped_matmul") diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index 209e507f9..93633ae9b 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -156,14 +156,12 @@ def set_ascend_forward_context( # Once the necessary conditions are met, support for MOE models will also be added. from vllm_ascend.quantization.quant_config import AscendQuantConfig addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \ - vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3", "qwen3_moe"] and \ + vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \ forward_context.layer_idx is not None if addrmsnorm_quant_fusion_enabled: forward_context.model_instance = model_instance forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense" - if vllm_config.model_config.hf_config.model_type == "qwen3_moe": - forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe" forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled if num_tokens is None and attn_metadata is not None: diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 331e5fa02..d289bb457 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -34,8 +34,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, maybe_save_kv_layer_to_connector, wait_for_kv_layer_from_connector) -from vllm_ascend.compilation.acl_graph import (get_graph_params, - update_graph_params_workspaces) +from vllm_ascend.compilation.acl_graph import get_graph_params from vllm_ascend.ops.attention import vanilla_chunked_prefill from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p, nd_to_nz_2d, nd_to_nz_spec) @@ -394,28 +393,13 @@ class AscendAttentionBackendImpl(AttentionImpl): forward_context: ForwardContext = get_forward_context() num_tokens = query.shape[0] if forward_context.capturing: - # Get workspace from cache or calculate it if not present. - workspace = graph_params.workspaces.get(num_tokens) - if workspace is None: - workspace = torch_npu._npu_paged_attention_get_workspace( - query=query, - key_cache=self.key_cache, - value_cache=self.value_cache, - num_kv_heads=self.num_kv_heads, - num_heads=self.num_heads, - scale_value=self.scale, - block_table=attn_metadata.block_tables, - context_lens=attn_metadata.seq_lens, - out=output) - update_graph_params_workspaces(num_tokens, workspace) - - # Handle graph capturing mode stream = torch_npu.npu.current_stream() event = torch.npu.ExternalEvent() event.wait(stream) event.reset(stream) graph_params.events[num_tokens].append(event) + graph_params.attn_params[num_tokens].append(( query, self.key_cache, @@ -429,7 +413,6 @@ class AscendAttentionBackendImpl(AttentionImpl): )) torch.npu.graph_task_group_begin(stream) - torch_npu._npu_paged_attention( query=query, key_cache=self.key_cache, @@ -439,8 +422,7 @@ class AscendAttentionBackendImpl(AttentionImpl): scale_value=self.scale, block_table=attn_metadata.block_tables, context_lens=attn_metadata.seq_lens, - out=output, - workspace=workspace) + out=output) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) else: diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py index 116d38274..8a4180773 100644 --- a/vllm_ascend/compilation/acl_graph.py +++ b/vllm_ascend/compilation/acl_graph.py @@ -215,17 +215,15 @@ def update_attn_params(update_stream, forward_context, runtime_shape): with torch.npu.stream(update_stream): torch.npu.graph_task_update_begin(update_stream, handle) - torch_npu._npu_paged_attention( - query=query, - key_cache=key_cache, - value_cache=value_cache, - num_kv_heads=num_kv_heads, - num_heads=num_heads, - scale_value=scale, - block_table=block_table, - context_lens=seq_lens, - out=output, - workspace=graph_params.workspaces.get(runtime_shape)) + torch_npu._npu_paged_attention(query=query, + key_cache=key_cache, + value_cache=value_cache, + num_kv_heads=num_kv_heads, + num_heads=num_heads, + scale_value=scale, + block_table=block_table, + context_lens=seq_lens, + out=output) torch.npu.graph_task_update_end(update_stream) event.record(update_stream) @@ -258,11 +256,5 @@ def set_graph_params(aclgraph_capture_sizes: set[int]): ) -def update_graph_params_workspaces(num_tokens: int, workspace: int): - global _graph_params - if _graph_params is not None: - _graph_params.workspaces[num_tokens] = workspace - - def get_graph_params(): return _graph_params diff --git a/vllm_ascend/ops/layernorm.py b/vllm_ascend/ops/layernorm.py index 344a8dcc0..3dfca5355 100644 --- a/vllm_ascend/ops/layernorm.py +++ b/vllm_ascend/ops/layernorm.py @@ -15,10 +15,9 @@ # This file is a part of the vllm-ascend project. # -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, cast import torch -from vllm.config import get_current_vllm_config from vllm.forward_context import get_forward_context from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm @@ -28,7 +27,6 @@ def _addrmsnorm_forward_oot( x: torch.Tensor, residual: torch.Tensor, layer: Optional[torch.nn.Module] = None, - bias: Optional[torch.nn.Parameter] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: import torch_npu @@ -41,7 +39,6 @@ def _addrmsnorm_forward_oot( self.weight, layer.aclnn_input_scale, layer.aclnn_input_offset, - beta=bias, epsilon=self.variance_epsilon) else: if is_310p(): @@ -53,31 +50,12 @@ def _addrmsnorm_forward_oot( else: x, _, residual = torch_npu.npu_add_rms_norm( x, residual, self.weight, self.variance_epsilon) - if bias is not None: - x.add_(bias) torch.ops.vllm.maybe_wait_prefetch_done(x) return x, residual class AscendRMSNorm(RMSNorm): - def __init__( - self, - hidden_size: int, - eps: float = 1e-6, - var_hidden_size: Optional[int] = None, - has_weight: bool = True, - dtype: Optional[torch.dtype] = None, - ) -> None: - super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) - vllm_config = get_current_vllm_config() - self.bias = None - # quantization with anti_method m4 will generate none-zero norm bias - if vllm_config is not None and vllm_config.quant_config is not None and \ - any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): - self.bias = torch.nn.Parameter(torch.zeros(hidden_size), - requires_grad=False) - def forward_oot( self, x: torch.Tensor, @@ -89,13 +67,10 @@ class AscendRMSNorm(RMSNorm): residual = torch.ops.vllm.maybe_chunk_residual(x, residual) assert x.size(0) == residual.size(0) x, residual = _addrmsnorm_forward_oot( - self, x, residual, self.next_need_quant_fusion_linear, - self.bias) + self, x, residual, self.next_need_quant_fusion_linear) return x, residual x, residual = torch_npu.npu_rms_norm(x, self.weight, self.variance_epsilon) - if self.bias is not None: - x.add_(self.bias) return x @property @@ -125,13 +100,6 @@ class AscendRMSNorm(RMSNorm): # does not need to be repeated if not forward_context.prefetch_mlp_enabled: forward_context.layer_idx += 1 - elif fusion_linear == "qkv_moe": - next_linear = model_instance.model.layers[ - layer_idx].self_attn.qkv_proj - forward_context.fusion_linear = "gate_moe" - elif fusion_linear == "gate_moe": - forward_context.fusion_linear = "qkv_moe" - forward_context.layer_idx += 1 from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod if next_linear is not None and \ not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod): @@ -139,6 +107,31 @@ class AscendRMSNorm(RMSNorm): return next_linear +class AscendQuantRMSNorm(AscendRMSNorm): + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + var_hidden_size: Optional[int] = None, + has_weight: bool = True, + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype) + self.bias = torch.nn.Parameter(torch.zeros(hidden_size), + requires_grad=False) + + def forward_oot( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + x, residual = super().forward_oot(x, residual) + return x.add_(self.bias), residual + return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias) + + class AscendGemmaRMSNorm(GemmaRMSNorm): def forward_oot( diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 17f2edae6..62faa9397 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -501,7 +501,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE, AscendSharedFusedMoE) - from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm + from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm, + AscendQuantRMSNorm, AscendRMSNorm) from vllm_ascend.ops.linear import (AscendColumnParallelLinear, AscendMergedColumnParallelLinear, AscendQKVParallelLinear, @@ -532,6 +533,11 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): "MultiHeadLatentAttention": AscendMultiHeadLatentAttention, } + if vllm_config is not None and \ + vllm_config.quant_config is not None and \ + any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()): + REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm + for name, op_cls in REGISTERED_ASCEND_OPS.items(): CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)