mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
Revert PTA upgrade PR (#3352)
we notice that torch npu 0919 doesn't work. This PR revert related change which rely on 0919 version. Revert PR: #3295 #3205 #3102 Related: #3353 - vLLM version: v0.11.0
This commit is contained in:
1
.github/workflows/_e2e_test.yaml
vendored
1
.github/workflows/_e2e_test.yaml
vendored
@ -173,7 +173,6 @@ jobs:
|
|||||||
if: ${{ inputs.type == 'full' }}
|
if: ${{ inputs.type == 'full' }}
|
||||||
run: |
|
run: |
|
||||||
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
pytest -sv tests/e2e/multicard/test_data_parallel.py
|
||||||
pytest -sv tests/e2e/multicard/test_full_graph_mode.py
|
|
||||||
pytest -sv tests/e2e/multicard/test_expert_parallel.py
|
pytest -sv tests/e2e/multicard/test_expert_parallel.py
|
||||||
# external_launcher test is not stable enough. Fix it later
|
# external_launcher test is not stable enough. Fix it later
|
||||||
# pytest -sv tests/e2e/multicard/test_external_launcher.py
|
# pytest -sv tests/e2e/multicard/test_external_launcher.py
|
||||||
|
@ -43,7 +43,7 @@ By using vLLM Ascend plugin, popular open-source models, including Transformer-l
|
|||||||
- Software:
|
- Software:
|
||||||
* Python >= 3.9, < 3.12
|
* Python >= 3.9, < 3.12
|
||||||
* CANN >= 8.2.rc1 (Ascend HDK version refers to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html))
|
* CANN >= 8.2.rc1 (Ascend HDK version refers to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html))
|
||||||
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250919
|
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
|
||||||
* vLLM (the same version as vllm-ascend)
|
* vLLM (the same version as vllm-ascend)
|
||||||
|
|
||||||
## Getting Started
|
## Getting Started
|
||||||
|
@ -44,7 +44,7 @@ vLLM 昇腾插件 (`vllm-ascend`) 是一个由社区维护的让vLLM在Ascend NP
|
|||||||
- 软件:
|
- 软件:
|
||||||
* Python >= 3.9, < 3.12
|
* Python >= 3.9, < 3.12
|
||||||
* CANN >= 8.2.rc1 (Ascend HDK 版本参考[这里](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html))
|
* CANN >= 8.2.rc1 (Ascend HDK 版本参考[这里](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html))
|
||||||
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250919
|
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
|
||||||
* vLLM (与vllm-ascend版本一致)
|
* vLLM (与vllm-ascend版本一致)
|
||||||
|
|
||||||
## 开始使用
|
## 开始使用
|
||||||
|
@ -13,7 +13,7 @@ This document describes how to install vllm-ascend manually.
|
|||||||
|---------------|----------------------------------|-------------------------------------------|
|
|---------------|----------------------------------|-------------------------------------------|
|
||||||
| Ascend HDK | Refer to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html) | Required for CANN |
|
| Ascend HDK | Refer to [here](https://www.hiascend.com/document/detail/zh/canncommercial/82RC1/releasenote/releasenote_0000.html) | Required for CANN |
|
||||||
| CANN | >= 8.2.RC1 | Required for vllm-ascend and torch-npu |
|
| CANN | >= 8.2.RC1 | Required for vllm-ascend and torch-npu |
|
||||||
| torch-npu | >= 2.7.1.dev20250919 | Required for vllm-ascend, No need to install manually, it will be auto installed in below steps |
|
| torch-npu | >= 2.7.1.dev20250724 | Required for vllm-ascend, No need to install manually, it will be auto installed in below steps |
|
||||||
| torch | >= 2.7.1 | Required for torch-npu and vllm |
|
| torch | >= 2.7.1 | Required for torch-npu and vllm |
|
||||||
|
|
||||||
You have 2 way to install:
|
You have 2 way to install:
|
||||||
|
@ -12,7 +12,7 @@ requires = [
|
|||||||
"scipy",
|
"scipy",
|
||||||
"setuptools>=64",
|
"setuptools>=64",
|
||||||
"setuptools-scm>=8",
|
"setuptools-scm>=8",
|
||||||
"torch-npu==2.7.1.dev20250919",
|
"torch-npu==2.7.1.dev20250724",
|
||||||
"torch>=2.7.1",
|
"torch>=2.7.1",
|
||||||
"torchvision",
|
"torchvision",
|
||||||
"wheel",
|
"wheel",
|
||||||
|
@ -24,4 +24,4 @@ numba
|
|||||||
# Install torch_npu
|
# Install torch_npu
|
||||||
--pre
|
--pre
|
||||||
--extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
|
--extra-index-url https://mirrors.huaweicloud.com/ascend/repos/pypi
|
||||||
torch-npu==2.7.1.dev20250919
|
torch-npu==2.7.1.dev20250724
|
||||||
|
@ -1,103 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
||||||
# Copyright 2023 The vLLM team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# This file is a part of the vllm-ascend project.
|
|
||||||
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
|
|
||||||
#
|
|
||||||
"""Compare the short outputs of HF and vLLM when using greedy sampling.
|
|
||||||
|
|
||||||
Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from vllm import SamplingParams
|
|
||||||
|
|
||||||
from tests.e2e.conftest import VllmRunner
|
|
||||||
from tests.e2e.model_utils import check_outputs_equal
|
|
||||||
|
|
||||||
|
|
||||||
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
|
|
||||||
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
|
|
||||||
del os.environ['HCCL_OP_EXPANSION_MODE']
|
|
||||||
prompts = [
|
|
||||||
('Solve the following math problem step by step.'
|
|
||||||
'The last line of your response should be of the form Answer: '
|
|
||||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
|
||||||
'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$'
|
|
||||||
'be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$,'
|
|
||||||
'$\\angle BDC = 90^\\circ$. Suppose $AD = 1$ and $\\frac{BD}{CD} = \\frac{3}{2}$.'
|
|
||||||
'If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$,'
|
|
||||||
'where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.'
|
|
||||||
),
|
|
||||||
('Solve the following math problem step by step.'
|
|
||||||
'The last line of your response should be of the form Answer: '
|
|
||||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
|
||||||
'Let $ABCD$ be a unit square in the plane. Points $X$ and $Y$ are chosen'
|
|
||||||
'independently and uniformly at random on the perimeter of $ABCD$.'
|
|
||||||
'If the expected value of the area of triangle $\\triangle AXY$'
|
|
||||||
'can be expressed as $\\frac{m}{n}$, for relatively prime positive'
|
|
||||||
'integers $m$ and $n$, compute $m+n$.'),
|
|
||||||
('Solve the following math problem step by step.'
|
|
||||||
'The last line of your response should be of the form Answer: '
|
|
||||||
'$Answer (without quotes) where $Answer is the answer to the problem.\n\n'
|
|
||||||
'Let $a, b, c$ be distinct numbers such that the equations $x^2 + ax + 1 = 0$'
|
|
||||||
'and $x^2 + bx + c = 0$ have a common real root, and the equations $x^2 + x + a = 0$'
|
|
||||||
'and $x^2 + cx + b = 0$ also have a common real root.'
|
|
||||||
'Compute the sum $a + b + c$.')
|
|
||||||
]
|
|
||||||
model = "Qwen/Qwen3-30B-A3B"
|
|
||||||
sampling_params = SamplingParams(max_tokens=5,
|
|
||||||
n=1,
|
|
||||||
temperature=0.0,
|
|
||||||
top_p=1.0,
|
|
||||||
top_k=1)
|
|
||||||
with VllmRunner(model,
|
|
||||||
max_model_len=1024,
|
|
||||||
tensor_parallel_size=2,
|
|
||||||
enforce_eager=False,
|
|
||||||
gpu_memory_utilization=0.95,
|
|
||||||
compilation_config={
|
|
||||||
"cudagraph_capture_sizes":
|
|
||||||
[4, 8, 12, 16, 24, 32, 40, 48],
|
|
||||||
"cudagraph_mode": "FULL_DECODE_ONLY"
|
|
||||||
}) as runner:
|
|
||||||
vllm_fullgraph_outputs = runner.model.generate(prompts,
|
|
||||||
sampling_params)
|
|
||||||
with VllmRunner(
|
|
||||||
model,
|
|
||||||
max_model_len=1024,
|
|
||||||
tensor_parallel_size=2,
|
|
||||||
enforce_eager=True,
|
|
||||||
gpu_memory_utilization=0.95,
|
|
||||||
) as runner:
|
|
||||||
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
|
|
||||||
|
|
||||||
vllm_fullgraph_outputs_list = []
|
|
||||||
for output in vllm_fullgraph_outputs:
|
|
||||||
vllm_fullgraph_outputs_list.append(
|
|
||||||
(output.outputs[0].index, output.outputs[0].text))
|
|
||||||
|
|
||||||
vllm_eager_outputs_list = []
|
|
||||||
for output in vllm_eager_outputs:
|
|
||||||
vllm_eager_outputs_list.append(
|
|
||||||
(output.outputs[0].index, output.outputs[0].text))
|
|
||||||
|
|
||||||
check_outputs_equal(
|
|
||||||
outputs_0_lst=vllm_eager_outputs_list,
|
|
||||||
outputs_1_lst=vllm_fullgraph_outputs_list,
|
|
||||||
name_0="vllm_eager_outputs",
|
|
||||||
name_1="vllm_fullgraph_outputs",
|
|
||||||
)
|
|
@ -405,109 +405,6 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
mock_paged_attention.assert_called_once()
|
mock_paged_attention.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
|
||||||
@patch('vllm_ascend.attention.attention_v1.get_graph_params')
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
|
||||||
@patch('torch_npu._npu_paged_attention')
|
|
||||||
@patch('torch.npu.graph_task_group_end')
|
|
||||||
@patch('torch.npu.graph_task_group_begin')
|
|
||||||
@patch('torch.npu.ExternalEvent')
|
|
||||||
@patch('torch_npu.npu.current_stream')
|
|
||||||
def test_paged_attention_with_existing_workspace(
|
|
||||||
self,
|
|
||||||
mock_get_forward_context,
|
|
||||||
mock_get_graph_params,
|
|
||||||
mock_npu_reshape_and_cache,
|
|
||||||
mock_paged_attention,
|
|
||||||
mock_graph_begin,
|
|
||||||
mock_graph_end,
|
|
||||||
mock_external_event_class,
|
|
||||||
mock_current_stream,
|
|
||||||
):
|
|
||||||
graph_params = MagicMock()
|
|
||||||
attn_metadata = MagicMock()
|
|
||||||
num_tokens = 10
|
|
||||||
|
|
||||||
graph_params.workspaces = {num_tokens: 10}
|
|
||||||
graph_params.events = {num_tokens: []}
|
|
||||||
graph_params.attn_params = {num_tokens: []}
|
|
||||||
graph_params.handles = {num_tokens: []}
|
|
||||||
|
|
||||||
query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size]
|
|
||||||
key_cache = MagicMock()
|
|
||||||
value_cache = MagicMock()
|
|
||||||
num_kv_heads = 4
|
|
||||||
num_heads = 8
|
|
||||||
scale = 0.1
|
|
||||||
output = torch.randn(2, 5, 8)
|
|
||||||
|
|
||||||
self_obj = MagicMock()
|
|
||||||
self_obj.key_cache = key_cache
|
|
||||||
self_obj.value_cache = value_cache
|
|
||||||
self_obj.num_kv_heads = num_kv_heads
|
|
||||||
self_obj.num_heads = num_heads
|
|
||||||
self_obj.scale = scale
|
|
||||||
|
|
||||||
mock_stream = MagicMock()
|
|
||||||
mock_current_stream.return_value = mock_stream
|
|
||||||
mock_event_instance = MagicMock()
|
|
||||||
mock_external_event_class.return_value = mock_event_instance
|
|
||||||
|
|
||||||
mock_handle = MagicMock()
|
|
||||||
mock_graph_end.return_value = mock_handle
|
|
||||||
|
|
||||||
workspace = graph_params.workspaces.get(num_tokens)
|
|
||||||
self.assertEqual(workspace, 10)
|
|
||||||
|
|
||||||
# 2. Handle graph capturing mode
|
|
||||||
stream = mock_current_stream()
|
|
||||||
event = mock_external_event_class()
|
|
||||||
event.wait(stream)
|
|
||||||
event.reset(stream)
|
|
||||||
graph_params.events[num_tokens].append(event)
|
|
||||||
graph_params.attn_params[num_tokens].append((
|
|
||||||
query,
|
|
||||||
self_obj.key_cache,
|
|
||||||
self_obj.value_cache,
|
|
||||||
self_obj.num_kv_heads,
|
|
||||||
self_obj.num_heads,
|
|
||||||
self_obj.scale,
|
|
||||||
attn_metadata.block_tables,
|
|
||||||
attn_metadata.seq_lens,
|
|
||||||
output,
|
|
||||||
))
|
|
||||||
|
|
||||||
mock_event_instance.wait.assert_called_once_with(mock_stream)
|
|
||||||
mock_event_instance.reset.assert_called_once_with(mock_stream)
|
|
||||||
self.assertEqual(len(graph_params.events[num_tokens]), 1)
|
|
||||||
self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
|
|
||||||
|
|
||||||
query = torch.randn(10, 8 * 64)
|
|
||||||
key = torch.randn(10, 8 * 64)
|
|
||||||
value = torch.randn(10, 8 * 64)
|
|
||||||
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
|
||||||
metadata = self.attn_metadata
|
|
||||||
metadata.attn_state = AscendAttentionState.DecodeOnly
|
|
||||||
metadata.seq_lens = torch.tensor([10])
|
|
||||||
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
|
||||||
metadata.num_actual_tokens = 10
|
|
||||||
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
|
||||||
layer = self.layer_no_quant
|
|
||||||
|
|
||||||
mock_get_forward_context.return_value = MagicMock(capturing=True)
|
|
||||||
mock_get_graph_params.return_value = graph_params
|
|
||||||
|
|
||||||
output = self.impl.forward(layer,
|
|
||||||
query,
|
|
||||||
key,
|
|
||||||
value,
|
|
||||||
kv_cache,
|
|
||||||
metadata,
|
|
||||||
trace_flag=False)
|
|
||||||
|
|
||||||
mock_paged_attention.assert_called_once()
|
|
||||||
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
|
|
||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||||
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
|
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
|
||||||
|
@ -24,7 +24,7 @@ def mock_add_rms_norm(x, residual, weight, eps):
|
|||||||
|
|
||||||
|
|
||||||
def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
|
def mock_add_rms_norm_quant(x, residual, weight, quant_scale, quant_offset,
|
||||||
beta, epsilon):
|
epsilon):
|
||||||
x_out = 2 * x
|
x_out = 2 * x
|
||||||
residual_out = 2 * residual
|
residual_out = 2 * residual
|
||||||
x_out_quant = x_out.to(torch.int8)
|
x_out_quant = x_out.to(torch.int8)
|
||||||
@ -94,7 +94,7 @@ class TestAscendRMSNorm(PytestBase):
|
|||||||
mock_model_instance = mocker.MagicMock()
|
mock_model_instance = mocker.MagicMock()
|
||||||
mock_forward_context.model_instance = mock_model_instance
|
mock_forward_context.model_instance = mock_model_instance
|
||||||
mock_model_instance.model.layers = [
|
mock_model_instance.model.layers = [
|
||||||
mocker.MagicMock() for _ in range(3)
|
mocker.MagicMock() for _ in range(2)
|
||||||
]
|
]
|
||||||
|
|
||||||
mock_layer_0 = mock_model_instance.model.layers[0]
|
mock_layer_0 = mock_model_instance.model.layers[0]
|
||||||
@ -124,7 +124,7 @@ class TestAscendRMSNorm(PytestBase):
|
|||||||
mock_forward_context.addrmsnorm_quant_fusion_enabled = True
|
mock_forward_context.addrmsnorm_quant_fusion_enabled = True
|
||||||
mock_forward_context.prefetch_mlp_enabled = False
|
mock_forward_context.prefetch_mlp_enabled = False
|
||||||
mock_forward_context.layer_idx = 0
|
mock_forward_context.layer_idx = 0
|
||||||
mock_forward_context.num_hidden_layers = 3
|
mock_forward_context.num_hidden_layers = 2
|
||||||
mock_forward_context.fusion_linear = "gate_up_dense"
|
mock_forward_context.fusion_linear = "gate_up_dense"
|
||||||
|
|
||||||
# Ensure fusion and layer_idx increment are handled correctly
|
# Ensure fusion and layer_idx increment are handled correctly
|
||||||
@ -144,32 +144,18 @@ class TestAscendRMSNorm(PytestBase):
|
|||||||
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
assert mock_forward_context.fusion_linear == "gate_up_dense"
|
||||||
assert mock_forward_context.layer_idx == 1
|
assert mock_forward_context.layer_idx == 1
|
||||||
|
|
||||||
mock_forward_context.fusion_linear = "gate_moe"
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 3
|
assert mock_get_forward_context.call_count == 3
|
||||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||||
assert mock_forward_context.layer_idx == 2
|
assert mock_forward_context.layer_idx == 2
|
||||||
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
x_out, residual_out = layer.forward_oot(x, residual)
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 4
|
assert mock_get_forward_context.call_count == 4
|
||||||
assert mock_forward_context.fusion_linear == "gate_moe"
|
assert mock_forward_context.fusion_linear == "qkv_dense"
|
||||||
assert mock_forward_context.layer_idx == 2
|
assert mock_forward_context.layer_idx == 2
|
||||||
|
|
||||||
# last layer returned directly
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 5
|
|
||||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
|
||||||
assert mock_forward_context.layer_idx == 3
|
|
||||||
|
|
||||||
x_out, residual_out = layer.forward_oot(x, residual)
|
|
||||||
|
|
||||||
assert mock_get_forward_context.call_count == 6
|
|
||||||
assert mock_forward_context.fusion_linear == "qkv_moe"
|
|
||||||
assert mock_forward_context.layer_idx == 3
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from tests.ut.base import TestBase
|
from tests.ut.base import TestBase
|
||||||
@ -17,10 +16,6 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
|||||||
self.hidden_size,
|
self.hidden_size,
|
||||||
dtype=torch.bfloat16)
|
dtype=torch.bfloat16)
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
True,
|
|
||||||
reason="fix me",
|
|
||||||
)
|
|
||||||
@patch("torch.distributed.all_to_all_single")
|
@patch("torch.distributed.all_to_all_single")
|
||||||
@patch("torch_npu.npu_moe_re_routing")
|
@patch("torch_npu.npu_moe_re_routing")
|
||||||
@patch("torch_npu.npu_grouped_matmul")
|
@patch("torch_npu.npu_grouped_matmul")
|
||||||
|
@ -156,14 +156,12 @@ def set_ascend_forward_context(
|
|||||||
# Once the necessary conditions are met, support for MOE models will also be added.
|
# Once the necessary conditions are met, support for MOE models will also be added.
|
||||||
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
from vllm_ascend.quantization.quant_config import AscendQuantConfig
|
||||||
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
|
addrmsnorm_quant_fusion_enabled = isinstance(vllm_config.quant_config, AscendQuantConfig) and \
|
||||||
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3", "qwen3_moe"] and \
|
vllm_config.model_config.hf_config.model_type in ["llama", "qwen2", "qwen3"] and \
|
||||||
forward_context.layer_idx is not None
|
forward_context.layer_idx is not None
|
||||||
if addrmsnorm_quant_fusion_enabled:
|
if addrmsnorm_quant_fusion_enabled:
|
||||||
forward_context.model_instance = model_instance
|
forward_context.model_instance = model_instance
|
||||||
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
forward_context.num_hidden_layers = vllm_config.model_config.hf_config.num_hidden_layers
|
||||||
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
|
forward_context.fusion_linear = "gate_up_dense" if forward_context.layer_idx == 0 else "qkv_dense"
|
||||||
if vllm_config.model_config.hf_config.model_type == "qwen3_moe":
|
|
||||||
forward_context.fusion_linear = "gate_moe" if forward_context.layer_idx == 0 else "qkv_moe"
|
|
||||||
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
|
forward_context.addrmsnorm_quant_fusion_enabled = addrmsnorm_quant_fusion_enabled
|
||||||
|
|
||||||
if num_tokens is None and attn_metadata is not None:
|
if num_tokens is None and attn_metadata is not None:
|
||||||
|
@ -34,8 +34,7 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
|||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
maybe_save_kv_layer_to_connector,
|
maybe_save_kv_layer_to_connector,
|
||||||
wait_for_kv_layer_from_connector)
|
wait_for_kv_layer_from_connector)
|
||||||
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||||
update_graph_params_workspaces)
|
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||||
nd_to_nz_2d, nd_to_nz_spec)
|
nd_to_nz_2d, nd_to_nz_spec)
|
||||||
@ -394,28 +393,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
num_tokens = query.shape[0]
|
num_tokens = query.shape[0]
|
||||||
if forward_context.capturing:
|
if forward_context.capturing:
|
||||||
# Get workspace from cache or calculate it if not present.
|
|
||||||
workspace = graph_params.workspaces.get(num_tokens)
|
|
||||||
if workspace is None:
|
|
||||||
workspace = torch_npu._npu_paged_attention_get_workspace(
|
|
||||||
query=query,
|
|
||||||
key_cache=self.key_cache,
|
|
||||||
value_cache=self.value_cache,
|
|
||||||
num_kv_heads=self.num_kv_heads,
|
|
||||||
num_heads=self.num_heads,
|
|
||||||
scale_value=self.scale,
|
|
||||||
block_table=attn_metadata.block_tables,
|
|
||||||
context_lens=attn_metadata.seq_lens,
|
|
||||||
out=output)
|
|
||||||
update_graph_params_workspaces(num_tokens, workspace)
|
|
||||||
|
|
||||||
# Handle graph capturing mode
|
|
||||||
stream = torch_npu.npu.current_stream()
|
stream = torch_npu.npu.current_stream()
|
||||||
|
|
||||||
event = torch.npu.ExternalEvent()
|
event = torch.npu.ExternalEvent()
|
||||||
event.wait(stream)
|
event.wait(stream)
|
||||||
event.reset(stream)
|
event.reset(stream)
|
||||||
graph_params.events[num_tokens].append(event)
|
graph_params.events[num_tokens].append(event)
|
||||||
|
|
||||||
graph_params.attn_params[num_tokens].append((
|
graph_params.attn_params[num_tokens].append((
|
||||||
query,
|
query,
|
||||||
self.key_cache,
|
self.key_cache,
|
||||||
@ -429,7 +413,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
))
|
))
|
||||||
|
|
||||||
torch.npu.graph_task_group_begin(stream)
|
torch.npu.graph_task_group_begin(stream)
|
||||||
|
|
||||||
torch_npu._npu_paged_attention(
|
torch_npu._npu_paged_attention(
|
||||||
query=query,
|
query=query,
|
||||||
key_cache=self.key_cache,
|
key_cache=self.key_cache,
|
||||||
@ -439,8 +422,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
scale_value=self.scale,
|
scale_value=self.scale,
|
||||||
block_table=attn_metadata.block_tables,
|
block_table=attn_metadata.block_tables,
|
||||||
context_lens=attn_metadata.seq_lens,
|
context_lens=attn_metadata.seq_lens,
|
||||||
out=output,
|
out=output)
|
||||||
workspace=workspace)
|
|
||||||
handle = torch.npu.graph_task_group_end(stream)
|
handle = torch.npu.graph_task_group_end(stream)
|
||||||
graph_params.handles[num_tokens].append(handle)
|
graph_params.handles[num_tokens].append(handle)
|
||||||
else:
|
else:
|
||||||
|
@ -215,17 +215,15 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
|||||||
|
|
||||||
with torch.npu.stream(update_stream):
|
with torch.npu.stream(update_stream):
|
||||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||||
torch_npu._npu_paged_attention(
|
torch_npu._npu_paged_attention(query=query,
|
||||||
query=query,
|
key_cache=key_cache,
|
||||||
key_cache=key_cache,
|
value_cache=value_cache,
|
||||||
value_cache=value_cache,
|
num_kv_heads=num_kv_heads,
|
||||||
num_kv_heads=num_kv_heads,
|
num_heads=num_heads,
|
||||||
num_heads=num_heads,
|
scale_value=scale,
|
||||||
scale_value=scale,
|
block_table=block_table,
|
||||||
block_table=block_table,
|
context_lens=seq_lens,
|
||||||
context_lens=seq_lens,
|
out=output)
|
||||||
out=output,
|
|
||||||
workspace=graph_params.workspaces.get(runtime_shape))
|
|
||||||
torch.npu.graph_task_update_end(update_stream)
|
torch.npu.graph_task_update_end(update_stream)
|
||||||
|
|
||||||
event.record(update_stream)
|
event.record(update_stream)
|
||||||
@ -258,11 +256,5 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def update_graph_params_workspaces(num_tokens: int, workspace: int):
|
|
||||||
global _graph_params
|
|
||||||
if _graph_params is not None:
|
|
||||||
_graph_params.workspaces[num_tokens] = workspace
|
|
||||||
|
|
||||||
|
|
||||||
def get_graph_params():
|
def get_graph_params():
|
||||||
return _graph_params
|
return _graph_params
|
||||||
|
@ -15,10 +15,9 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
#
|
#
|
||||||
|
|
||||||
from typing import Optional, Tuple, Union
|
from typing import Optional, Tuple, Union, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from vllm.config import get_current_vllm_config
|
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||||
|
|
||||||
@ -28,7 +27,6 @@ def _addrmsnorm_forward_oot(
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
residual: torch.Tensor,
|
residual: torch.Tensor,
|
||||||
layer: Optional[torch.nn.Module] = None,
|
layer: Optional[torch.nn.Module] = None,
|
||||||
bias: Optional[torch.nn.Parameter] = None,
|
|
||||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
|
||||||
@ -41,7 +39,6 @@ def _addrmsnorm_forward_oot(
|
|||||||
self.weight,
|
self.weight,
|
||||||
layer.aclnn_input_scale,
|
layer.aclnn_input_scale,
|
||||||
layer.aclnn_input_offset,
|
layer.aclnn_input_offset,
|
||||||
beta=bias,
|
|
||||||
epsilon=self.variance_epsilon)
|
epsilon=self.variance_epsilon)
|
||||||
else:
|
else:
|
||||||
if is_310p():
|
if is_310p():
|
||||||
@ -53,31 +50,12 @@ def _addrmsnorm_forward_oot(
|
|||||||
else:
|
else:
|
||||||
x, _, residual = torch_npu.npu_add_rms_norm(
|
x, _, residual = torch_npu.npu_add_rms_norm(
|
||||||
x, residual, self.weight, self.variance_epsilon)
|
x, residual, self.weight, self.variance_epsilon)
|
||||||
if bias is not None:
|
|
||||||
x.add_(bias)
|
|
||||||
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
torch.ops.vllm.maybe_wait_prefetch_done(x)
|
||||||
return x, residual
|
return x, residual
|
||||||
|
|
||||||
|
|
||||||
class AscendRMSNorm(RMSNorm):
|
class AscendRMSNorm(RMSNorm):
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
hidden_size: int,
|
|
||||||
eps: float = 1e-6,
|
|
||||||
var_hidden_size: Optional[int] = None,
|
|
||||||
has_weight: bool = True,
|
|
||||||
dtype: Optional[torch.dtype] = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
|
||||||
vllm_config = get_current_vllm_config()
|
|
||||||
self.bias = None
|
|
||||||
# quantization with anti_method m4 will generate none-zero norm bias
|
|
||||||
if vllm_config is not None and vllm_config.quant_config is not None and \
|
|
||||||
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
|
|
||||||
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
|
||||||
requires_grad=False)
|
|
||||||
|
|
||||||
def forward_oot(
|
def forward_oot(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
@ -89,13 +67,10 @@ class AscendRMSNorm(RMSNorm):
|
|||||||
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
residual = torch.ops.vllm.maybe_chunk_residual(x, residual)
|
||||||
assert x.size(0) == residual.size(0)
|
assert x.size(0) == residual.size(0)
|
||||||
x, residual = _addrmsnorm_forward_oot(
|
x, residual = _addrmsnorm_forward_oot(
|
||||||
self, x, residual, self.next_need_quant_fusion_linear,
|
self, x, residual, self.next_need_quant_fusion_linear)
|
||||||
self.bias)
|
|
||||||
return x, residual
|
return x, residual
|
||||||
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
x, residual = torch_npu.npu_rms_norm(x, self.weight,
|
||||||
self.variance_epsilon)
|
self.variance_epsilon)
|
||||||
if self.bias is not None:
|
|
||||||
x.add_(self.bias)
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -125,13 +100,6 @@ class AscendRMSNorm(RMSNorm):
|
|||||||
# does not need to be repeated
|
# does not need to be repeated
|
||||||
if not forward_context.prefetch_mlp_enabled:
|
if not forward_context.prefetch_mlp_enabled:
|
||||||
forward_context.layer_idx += 1
|
forward_context.layer_idx += 1
|
||||||
elif fusion_linear == "qkv_moe":
|
|
||||||
next_linear = model_instance.model.layers[
|
|
||||||
layer_idx].self_attn.qkv_proj
|
|
||||||
forward_context.fusion_linear = "gate_moe"
|
|
||||||
elif fusion_linear == "gate_moe":
|
|
||||||
forward_context.fusion_linear = "qkv_moe"
|
|
||||||
forward_context.layer_idx += 1
|
|
||||||
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
|
||||||
if next_linear is not None and \
|
if next_linear is not None and \
|
||||||
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
|
not isinstance(next_linear.quant_method.quant_method, AscendW8A8LinearMethod):
|
||||||
@ -139,6 +107,31 @@ class AscendRMSNorm(RMSNorm):
|
|||||||
return next_linear
|
return next_linear
|
||||||
|
|
||||||
|
|
||||||
|
class AscendQuantRMSNorm(AscendRMSNorm):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hidden_size: int,
|
||||||
|
eps: float = 1e-6,
|
||||||
|
var_hidden_size: Optional[int] = None,
|
||||||
|
has_weight: bool = True,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(hidden_size, eps, var_hidden_size, has_weight, dtype)
|
||||||
|
self.bias = torch.nn.Parameter(torch.zeros(hidden_size),
|
||||||
|
requires_grad=False)
|
||||||
|
|
||||||
|
def forward_oot(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
residual: Optional[torch.Tensor] = None,
|
||||||
|
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||||
|
if residual is not None:
|
||||||
|
x, residual = super().forward_oot(x, residual)
|
||||||
|
return x.add_(self.bias), residual
|
||||||
|
return cast(torch.Tensor, super().forward_oot(x)).add_(self.bias)
|
||||||
|
|
||||||
|
|
||||||
class AscendGemmaRMSNorm(GemmaRMSNorm):
|
class AscendGemmaRMSNorm(GemmaRMSNorm):
|
||||||
|
|
||||||
def forward_oot(
|
def forward_oot(
|
||||||
|
@ -501,7 +501,8 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
|||||||
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
|
||||||
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
|
from vllm_ascend.ops.common_fused_moe import (AscendFusedMoE,
|
||||||
AscendSharedFusedMoE)
|
AscendSharedFusedMoE)
|
||||||
from vllm_ascend.ops.layernorm import AscendGemmaRMSNorm, AscendRMSNorm
|
from vllm_ascend.ops.layernorm import (AscendGemmaRMSNorm,
|
||||||
|
AscendQuantRMSNorm, AscendRMSNorm)
|
||||||
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
from vllm_ascend.ops.linear import (AscendColumnParallelLinear,
|
||||||
AscendMergedColumnParallelLinear,
|
AscendMergedColumnParallelLinear,
|
||||||
AscendQKVParallelLinear,
|
AscendQKVParallelLinear,
|
||||||
@ -532,6 +533,11 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
|
|||||||
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
|
"MultiHeadLatentAttention": AscendMultiHeadLatentAttention,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if vllm_config is not None and \
|
||||||
|
vllm_config.quant_config is not None and \
|
||||||
|
any("norm.bias" in name for name in vllm_config.quant_config.quant_description.keys()):
|
||||||
|
REGISTERED_ASCEND_OPS["RMSNorm"] = AscendQuantRMSNorm
|
||||||
|
|
||||||
for name, op_cls in REGISTERED_ASCEND_OPS.items():
|
for name, op_cls in REGISTERED_ASCEND_OPS.items():
|
||||||
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
|
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user