Revert PTA upgrade PR (#3352)

we notice that torch npu 0919 doesn't work. This PR revert related
change which rely on 0919 version.
Revert PR: #3295  #3205  #3102 

Related: #3353

- vLLM version: v0.11.0
This commit is contained in:
wangxiyuan
2025-10-10 14:09:53 +08:00
committed by GitHub
parent 601a37aeff
commit ba19dd3183
15 changed files with 57 additions and 312 deletions

View File

@ -173,7 +173,6 @@ jobs:
if: ${{ inputs.type == 'full' }}
run: |
pytest -sv tests/e2e/multicard/test_data_parallel.py
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
pytest -sv tests/e2e/multicard/test_expert_parallel.py
# external_launcher test is not stable enough. Fix it later
# pytest -sv tests/e2e/multicard/test_external_launcher.py

View File

@ -43,7 +43,7 @@ By using vLLM Ascend plugin, popular open-source models, including Transformer-l
- Software:
* Python >= 3.9, < 3.12
* CANN >= 8.2.rc1 (Ascend HDK version refers to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html))
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250919
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
* vLLM (the same version as vllm-ascend)
## Getting Started

View File

@ -44,7 +44,7 @@ vLLM 昇腾插件 (`vllm-ascend`) 是一个由社区维护的让vLLM在Ascend NP
- 软件:
* Python >= 3.9, < 3.12
* CANN >= 8.2.rc1 (Ascend HDK 版本参考[这里](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html))
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250919
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
* vLLM (与vllm-ascend版本一致)
## 开始使用

View File

@ -13,7 +13,7 @@ This document describes how to install vllm-ascend manually.
|---------------|----------------------------------|-------------------------------------------|
| Ascend HDK | Refer to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html) | Required for CANN |
| CANN | >= 8.2.RC1 | Required for vllm-ascend and torch-npu |
| torch-npu | >= 2.7.1.dev20250919 | Required for vllm-ascend, No need to install manually, it will be auto installed in below steps |
| torch-npu | >= 2.7.1.dev20250724 | Required for vllm-ascend, No need to install manually, it will be auto installed in below steps |
| torch | >= 2.7.1 | Required for torch-npu and vllm |
You have 2 way to install:

View File

@ -12,7 +12,7 @@ requires = [
"scipy",
"setuptools>=64",
"setuptools-scm>=8",
"torch-npu==2.7.1.dev20250919",
"torch-npu==2.7.1.dev20250724",
"torch>=2.7.1",
"torchvision",
"wheel",

View File

@ -24,4 +24,4 @@ numba
# Install torch_npu
--pre
--extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
torch-npu==2.7.1.dev20250919
torch-npu==2.7.1.dev20250724

View File

@ -1,103 +0,0 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
"""
import os
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
prompts = [
('Solve the following math problem step by step.'
'The last line of your response should be of the form Answer: '
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$'
'be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$,'
'$\\angle BDC = 90^\\circ$. Suppose $AD = 1$ and $\\frac{BD}{CD} = \\frac{3}{2}$.'
'If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$,'
'where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.'
),
('Solve the following math problem step by step.'
'The last line of your response should be of the form Answer: '
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
'Let $ABCD$ be a unit square in the plane. Points $X$ and $Y$ are chosen'
'independently and uniformly at random on the perimeter of $ABCD$.'
'If the expected value of the area of triangle $\\triangle AXY$'
'can be expressed as $\\frac{m}{n}$, for relatively prime positive'
'integers $m$ and $n$, compute $m+n$.'),
('Solve the following math problem step by step.'
'The last line of your response should be of the form Answer: '
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
'Let $a, b, c$ be distinct numbers such that the equations $x^2 + ax + 1 = 0$'
'and $x^2 + bx + c = 0$ have a common real root, and the equations $x^2 + x + a = 0$'
'and $x^2 + cx + b = 0$ also have a common real root.'
'Compute the sum $a + b + c$.')
]
model = "Qwen/Qwen3-30B-A3B"
sampling_params = SamplingParams(max_tokens=5,
n=1,
temperature=0.0,
top_p=1.0,
top_k=1)
with VllmRunner(model,
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=False,
gpu_memory_utilization=0.95,
compilation_config={
"cudagraph_capture_sizes":
[4, 8, 12, 16, 24, 32, 40, 48],
"cudagraph_mode": "FULL_DECODE_ONLY"
}) as runner:
vllm_fullgraph_outputs = runner.model.generate(prompts,
sampling_params)
with VllmRunner(
model,
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=True,
gpu_memory_utilization=0.95,
) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
vllm_fullgraph_outputs_list = []
for output in vllm_fullgraph_outputs:
vllm_fullgraph_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
vllm_eager_outputs_list = []
for output in vllm_eager_outputs:
vllm_eager_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
check_outputs_equal(
outputs_0_lst=vllm_eager_outputs_list,
outputs_1_lst=vllm_fullgraph_outputs_list,
name_0="vllm_eager_outputs",
name_1="vllm_fullgraph_outputs",
)

View File

@ -405,109 +405,6 @@ class TestAscendAttentionBackendImpl(TestBase):
mock_paged_attention.assert_called_once()
assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('vllm_ascend.attention.attention_v1.get_graph_params')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
@patch('torch.npu.graph_task_group_end')
@patch('torch.npu.graph_task_group_begin')
@patch('torch.npu.ExternalEvent')
@patch('torch_npu.npu.current_stream')
def test_paged_attention_with_existing_workspace(
self,
mock_get_forward_context,
mock_get_graph_params,
mock_npu_reshape_and_cache,
mock_paged_attention,
mock_graph_begin,
mock_graph_end,
mock_external_event_class,
mock_current_stream,
):
graph_params = MagicMock()
attn_metadata = MagicMock()
num_tokens = 10
graph_params.workspaces = {num_tokens: 10}
graph_params.events = {num_tokens: []}
graph_params.attn_params = {num_tokens: []}
graph_params.handles = {num_tokens: []}
query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size]
key_cache = MagicMock()
value_cache = MagicMock()
num_kv_heads = 4
num_heads = 8
scale = 0.1
output = torch.randn(2, 5, 8)
self_obj = MagicMock()
self_obj.key_cache = key_cache
self_obj.value_cache = value_cache
self_obj.num_kv_heads = num_kv_heads
self_obj.num_heads = num_heads
self_obj.scale = scale
mock_stream = MagicMock()
mock_current_stream.return_value = mock_stream
mock_event_instance = MagicMock()
mock_external_event_class.return_value = mock_event_instance
mock_handle = MagicMock()
mock_graph_end.return_value = mock_handle
workspace = graph_params.workspaces.get(num_tokens)
self.assertEqual(workspace, 10)
# 2. Handle graph capturing mode
stream = mock_current_stream()
event = mock_external_event_class()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
query,
self_obj.key_cache,
self_obj.value_cache,
self_obj.num_kv_heads,
self_obj.num_heads,
self_obj.scale,
attn_metadata.block_tables,
attn_metadata.seq_lens,
output,
))
mock_event_instance.wait.assert_called_once_with(mock_stream)
mock_event_instance.reset.assert_called_once_with(mock_stream)
self.assertEqual(len(graph_params.events[num_tokens]), 1)
self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=True)
mock_get_graph_params.return_value = graph_params
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_paged_attention.assert_called_once()
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu.npu_fused_infer_attention_score')
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,

View File

@ -24,7 +24,7 @@ def mock_add_rms_norm(x, residual, weight, eps):
def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
beta, epsilon):
epsilon):
x_out = 2 * x
residual_out = 2 * residual
x_out_quant = x_out.to(torch.int8)
@ -94,7 +94,7 @@ class TestAscendRMSNorm(PytestBase):
mock_model_instance = mocker.MagicMock()
mock_forward_context.model_instance = mock_model_instance
mock_model_instance.model.layers = [
mocker.MagicMock() for _ in range(3)
mocker.MagicMock() for _ in range(2)
]
mock_layer_0 = mock_model_instance.model.layers[0]
@ -124,7 +124,7 @@ class TestAscendRMSNorm(PytestBase):
mock_forward_context.addrmsnorm_quant_fusion_enabled = True
mock_forward_context.prefetch_mlp_enabled = False
mock_forward_context.layer_idx = 0
mock_forward_context.num_hidden_layers = 3
mock_forward_context.num_hidden_layers = 2
mock_forward_context.fusion_linear = "gate_up_dense"
# Ensure fusion and layer_idx increment are handled correctly
@ -144,32 +144,18 @@ class TestAscendRMSNorm(PytestBase):
assert mock_forward_context.fusion_linear == "gate_up_dense"
assert mock_forward_context.layer_idx == 1
mock_forward_context.fusion_linear = "gate_moe"
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 3
assert mock_forward_context.fusion_linear == "qkv_moe"
assert mock_forward_context.fusion_linear == "qkv_dense"
assert mock_forward_context.layer_idx == 2
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 4
assert mock_forward_context.fusion_linear == "gate_moe"
assert mock_forward_context.fusion_linear == "qkv_dense"
assert mock_forward_context.layer_idx == 2
# last layer returned directly
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 5
assert mock_forward_context.fusion_linear == "qkv_moe"
assert mock_forward_context.layer_idx == 3
x_out, residual_out = layer.forward_oot(x, residual)
assert mock_get_forward_context.call_count == 6
assert mock_forward_context.fusion_linear == "qkv_moe"
assert mock_forward_context.layer_idx == 3
if __name__ == '__main__':
unittest.main()

View File

@ -1,6 +1,5 @@
from unittest.mock import MagicMock, patch
import pytest
import torch
from tests.ut.base import TestBase
@ -17,10 +16,6 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
self.hidden_size,
dtype=torch.bfloat16)
@pytest.mark.skipif(
True,
reason="fix me",
)
@patch("torch.distributed.all_to_all_single")
@patch("torch_npu.npu_moe_re_routing")
@patch("torch_npu.npu_grouped_matmul")

View File

@ -156,14 +156,12 @@ def set_ascend_forward_context(
# Once the necessary conditions are met, support for MOE models will also be added.
from vllm_ascend.quantization.quant_config import AscendQuantConfig
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3", "qwen3_moe"] and \
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
forward_context.layer_idx is not None
if addrmsnorm_quant_fusion_enabled:
forward_context.model_instance = model_instance
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
if vllm_config.model_config.hf_config.model_type == "qwen3_moe":
forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe"
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
if num_tokens is None and attn_metadata is not None:

View File

@ -34,8 +34,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
maybe_save_kv_layer_to_connector,
wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
update_graph_params_workspaces)
from vllm_ascend.compilation.acl_graph import get_graph_params
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
nd_to_nz_2d, nd_to_nz_spec)
@ -394,28 +393,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
forward_context: ForwardContext = get_forward_context()
num_tokens = query.shape[0]
if forward_context.capturing:
# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_paged_attention_get_workspace(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
update_graph_params_workspaces(num_tokens, workspace)
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()
event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
query,
self.key_cache,
@ -429,7 +413,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
))
torch.npu.graph_task_group_begin(stream)
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
@ -439,8 +422,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output,
workspace=workspace)
out=output)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)
else:

View File

@ -215,17 +215,15 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
with torch.npu.stream(update_stream):
torch.npu.graph_task_update_begin(update_stream, handle)
torch_npu._npu_paged_attention(
query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output,
workspace=graph_params.workspaces.get(runtime_shape))
torch_npu._npu_paged_attention(query=query,
key_cache=key_cache,
value_cache=value_cache,
num_kv_heads=num_kv_heads,
num_heads=num_heads,
scale_value=scale,
block_table=block_table,
context_lens=seq_lens,
out=output)
torch.npu.graph_task_update_end(update_stream)
event.record(update_stream)
@ -258,11 +256,5 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
)
def update_graph_params_workspaces(num_tokens: int, workspace: int):
global _graph_params
if _graph_params is not None:
_graph_params.workspaces[num_tokens] = workspace
def get_graph_params():
return _graph_params

View File

@ -15,10 +15,9 @@
# This file is a part of the vllm-ascend project.
#
from typing import Optional, Tuple, Union
from typing import Optional, Tuple, Union, cast
import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
@ -28,7 +27,6 @@ def _addrmsnorm_forward_oot(
x: torch.Tensor,
residual: torch.Tensor,
layer: Optional[torch.nn.Module] = None,
bias: Optional[torch.nn.Parameter] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu
@ -41,7 +39,6 @@ def _addrmsnorm_forward_oot(
self.weight,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
beta=bias,
epsilon=self.variance_epsilon)
else:
if is_310p():
@ -53,31 +50,12 @@ def _addrmsnorm_forward_oot(
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
if bias is not None:
x.add_(bias)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual
class AscendRMSNorm(RMSNorm):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
vllm_config = get_current_vllm_config()
self.bias = None
# quantization with anti_method m4 will generate none-zero norm bias
if vllm_config is not None and vllm_config.quant_config is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
def forward_oot(
self,
x: torch.Tensor,
@ -89,13 +67,10 @@ class AscendRMSNorm(RMSNorm):
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
assert x.size(0) == residual.size(0)
x, residual = _addrmsnorm_forward_oot(
self, x, residual, self.next_need_quant_fusion_linear,
self.bias)
self, x, residual, self.next_need_quant_fusion_linear)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
if self.bias is not None:
x.add_(self.bias)
return x
@property
@ -125,13 +100,6 @@ class AscendRMSNorm(RMSNorm):
# does not need to be repeated
if not forward_context.prefetch_mlp_enabled:
forward_context.layer_idx += 1
elif fusion_linear == "qkv_moe":
next_linear = model_instance.model.layers[
layer_idx].self_attn.qkv_proj
forward_context.fusion_linear = "gate_moe"
elif fusion_linear == "gate_moe":
forward_context.fusion_linear = "qkv_moe"
forward_context.layer_idx += 1
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
if next_linear is not None and \
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
@ -139,6 +107,31 @@ class AscendRMSNorm(RMSNorm):
return next_linear
class AscendQuantRMSNorm(AscendRMSNorm):
def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
requires_grad=False)
def forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
x, residual = super().forward_oot(x, residual)
return x.add_(self.bias), residual
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
class AscendGemmaRMSNorm(GemmaRMSNorm):
def forward_oot(

View File

@ -501,7 +501,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE)
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
AscendQuantRMSNorm, AscendRMSNorm)
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
AscendMergedColumnParallelLinear,
AscendQKVParallelLinear,
@ -532,6 +533,11 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
}
if vllm_config is not None and \
vllm_config.quant_config is not None and \
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm
for name, op_cls in REGISTERED_ASCEND_OPS.items():
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)