Fix PagedAttention to support full graph mode (#3436)

### What this PR does / why we need it?
Pre-compute the workspace memory size required by the `_npu_paged_attention` operator when capturing the ACL graph, cache it per captured token count, and pass it to the operator explicitly on both capture and replay, to avoid deadlocks during resource cleanup. This PR requires a torch_npu build dated 0920 or newer.
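
A condensed sketch of the new flow (simplified from the diffs below; `self.`/metadata prefixes are dropped and the fallback for older torch_npu builds is omitted):

```python
# At ACL graph capture time (attention_v1.py): query the kernel's workspace
# size once per captured token count and cache it in the graph parameters.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
    workspace = torch_npu._npu_paged_attention_get_workspace(
        query=query, key_cache=key_cache, value_cache=value_cache,
        num_kv_heads=num_kv_heads, num_heads=num_heads, scale_value=scale,
        block_table=block_tables, context_lens=seq_lens, out=output)
    update_graph_params_workspaces(num_tokens, workspace)

# Both when capturing and when updating the captured graph (acl_graph.py),
# the cached workspace is handed to the kernel instead of being allocated
# inside the operator.
torch_npu._npu_paged_attention(
    query=query, key_cache=key_cache, value_cache=value_cache,
    num_kv_heads=num_kv_heads, num_heads=num_heads, scale_value=scale,
    block_table=block_tables, context_lens=seq_lens, out=output,
    workspace=workspace)
```

On older torch_npu builds, `version_check()` returns False and the previous call without the `workspace` argument is used.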
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- New e2e test `test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH`, which checks that full-graph (`cudagraph_mode: "FULL_DECODE_ONLY"`) outputs match eager outputs for Qwen/Qwen3-30B-A3B.
- New unit test `test_paged_attention_with_existing_workspace`, which exercises the cached-workspace path of the Ascend attention backend.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
Commit 9eb62935b8 (parent 22a1d91cf5), authored by XiaoxinWang and committed via GitHub on 2025-10-14 16:10:09 +08:00.
5 changed files with 271 additions and 21 deletions.

View File

@@ -0,0 +1,72 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
"""
import os
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
del os.environ['HCCL_OP_EXPANSION_MODE']
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
model = "Qwen/Qwen3-30B-A3B"
sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
with VllmRunner(model,
max_model_len=1024,
tensor_parallel_size=2,
enforce_eager=False,
compilation_config={"cudagraph_mode":
"FULL_DECODE_ONLY"}) as runner:
vllm_fullgraph_outputs = runner.model.generate(prompts,
sampling_params)
with VllmRunner(
model,
max_model_len=1024,
enforce_eager=True,
) as runner:
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
vllm_fullgraph_outputs_list = []
for output in vllm_fullgraph_outputs:
vllm_fullgraph_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
vllm_eager_outputs_list = []
for output in vllm_eager_outputs:
vllm_eager_outputs_list.append(
(output.outputs[0].index, output.outputs[0].text))
check_outputs_equal(
outputs_0_lst=vllm_eager_outputs_list,
outputs_1_lst=vllm_fullgraph_outputs_list,
name_0="vllm_eager_outputs",
name_1="vllm_fullgraph_outputs",
)

View File

@@ -405,6 +405,113 @@ class TestAscendAttentionBackendImpl(TestBase):
        mock_paged_attention.assert_called_once()
        assert output.shape == (10, 8 * 64)
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
@patch('vllm_ascend.attention.attention_v1.get_graph_params')
@patch('torch_npu._npu_reshape_and_cache')
@patch('torch_npu._npu_paged_attention')
@patch('torch.npu.graph_task_group_end')
@patch('torch.npu.graph_task_group_begin')
@patch('torch.npu.ExternalEvent')
@patch('torch_npu.npu.current_stream')
@patch('vllm_ascend.attention.attention_v1.weak_ref_tensors')
def test_paged_attention_with_existing_workspace(
self,
mock_get_forward_context,
mock_get_graph_params,
mock_npu_reshape_and_cache,
mock_paged_attention,
mock_graph_begin,
mock_graph_end,
mock_external_event_class,
mock_current_stream,
mock_weak_ref_tensors,
):
graph_params = MagicMock()
attn_metadata = MagicMock()
num_tokens = 10
graph_params.workspaces = {num_tokens: 10}
graph_params.events = {num_tokens: []}
graph_params.attn_params = {num_tokens: []}
graph_params.handles = {num_tokens: []}
query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size]
key_cache = MagicMock()
value_cache = MagicMock()
num_kv_heads = 4
num_heads = 8
scale = 0.1
output = torch.randn(2, 5, 8)
self_obj = MagicMock()
self_obj.key_cache = key_cache
self_obj.value_cache = value_cache
self_obj.num_kv_heads = num_kv_heads
self_obj.num_heads = num_heads
self_obj.scale = scale
mock_stream = MagicMock()
mock_current_stream.return_value = mock_stream
mock_event_instance = MagicMock()
mock_external_event_class.return_value = mock_event_instance
mock_handle = MagicMock()
mock_graph_end.return_value = mock_handle
workspace = graph_params.workspaces.get(num_tokens)
self.assertEqual(workspace, 10)
weak_ref_tensors = MagicMock(side_effect=lambda x: x)
# 2. Handle graph capturing mode
stream = mock_current_stream()
event = mock_external_event_class()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)
graph_params.attn_params[num_tokens].append((
weak_ref_tensors(query),
weak_ref_tensors(self_obj.key_cache),
weak_ref_tensors(self_obj.value_cache),
self_obj.num_kv_heads,
self_obj.num_heads,
self_obj.scale,
weak_ref_tensors(attn_metadata.block_tables),
attn_metadata.seq_lens,
output,
))
mock_event_instance.wait.assert_called_once_with(mock_stream)
mock_event_instance.reset.assert_called_once_with(mock_stream)
self.assertEqual(len(graph_params.events[num_tokens]), 1)
self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
query = torch.randn(10, 8 * 64)
key = torch.randn(10, 8 * 64)
value = torch.randn(10, 8 * 64)
kv_cache = torch.empty(2, 5, 128, 8, 64)
metadata = self.attn_metadata
metadata.attn_state = AscendAttentionState.DecodeOnly
metadata.seq_lens = torch.tensor([10])
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
metadata.num_actual_tokens = 10
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
layer = self.layer_no_quant
mock_get_forward_context.return_value = MagicMock(capturing=True)
mock_get_graph_params.return_value = graph_params
output = self.impl.forward(layer,
query,
key,
value,
kv_cache,
metadata,
trace_flag=False)
mock_paged_attention.assert_called_once()
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
    @patch('torch_npu._npu_reshape_and_cache')
    @patch('torch_npu.npu_fused_infer_attention_score')
    def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,

View File

@@ -33,8 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
                                          version_check,
                                          wait_for_kv_layer_from_connector)
from vllm_ascend.compilation.acl_graph import (get_graph_params,
                                               update_graph_params_workspaces)
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                               nd_to_nz_2d, nd_to_nz_spec)
@@ -289,6 +291,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.key_cache = None
        self.value_cache = None
        self.torch_npu_check = version_check()

    def _forward_prefill_no_cache(
        self,
@@ -396,13 +399,29 @@ class AscendAttentionBackendImpl(AttentionImpl):
            forward_context: ForwardContext = get_forward_context()
            num_tokens = query.shape[0]
            if forward_context.capturing:
                if self.torch_npu_check:
                    # Get workspace from cache or calculate it if not present.
                    workspace = graph_params.workspaces.get(num_tokens)
                    if workspace is None:
                        workspace = torch_npu._npu_paged_attention_get_workspace(
                            query=query,
                            key_cache=self.key_cache,
                            value_cache=self.value_cache,
                            num_kv_heads=self.num_kv_heads,
                            num_heads=self.num_heads,
                            scale_value=self.scale,
                            block_table=attn_metadata.block_tables,
                            context_lens=attn_metadata.seq_lens,
                            out=output)
                        update_graph_params_workspaces(num_tokens, workspace)

                # Handle graph capturing mode
                stream = torch_npu.npu.current_stream()

                event = torch.npu.ExternalEvent()
                event.wait(stream)
                event.reset(stream)
                graph_params.events[num_tokens].append(event)

                graph_params.attn_params[num_tokens].append((
                    weak_ref_tensors(query),
                    weak_ref_tensors(self.key_cache),
@@ -416,16 +435,30 @@ class AscendAttentionBackendImpl(AttentionImpl):
                ))

                torch.npu.graph_task_group_begin(stream)
                if self.torch_npu_check:
                    torch_npu._npu_paged_attention(
                        query=query,
                        key_cache=self.key_cache,
                        value_cache=self.value_cache,
                        num_kv_heads=self.num_kv_heads,
                        num_heads=self.num_heads,
                        scale_value=self.scale,
                        block_table=attn_metadata.block_tables,
                        context_lens=attn_metadata.seq_lens,
                        out=output,
                        workspace=workspace)
                else:
                    torch_npu._npu_paged_attention(
                        query=query,
                        key_cache=self.key_cache,
                        value_cache=self.value_cache,
                        num_kv_heads=self.num_kv_heads,
                        num_heads=self.num_heads,
                        scale_value=self.scale,
                        block_table=attn_metadata.block_tables,
                        context_lens=attn_metadata.seq_lens,
                        out=output)
                handle = torch.npu.graph_task_group_end(stream)
                graph_params.handles[num_tokens].append(handle)
            else:

View File

@@ -1,7 +1,9 @@
import functools
from dataclasses import dataclass
from typing import Any, List

import torch
import torch_npu
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                          has_kv_transfer_group,
                                          is_v1_kv_transfer_group)
@@ -139,3 +141,17 @@ def maybe_save_kv_layer_to_connector(
        return
    # TODO: assert ascendMetadata
    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)


@functools.cache
def version_check():
    # torch_npu dev builds dated 2025-09-19 or later support the
    # workspace-based _npu_paged_attention path used for full-graph capture.
    import re
    torch_npu_version = torch_npu.version.__version__
    date_pattern = r'dev(\d{8})'
    match = re.search(date_pattern, torch_npu_version)
    if match:
        full_date = match.group(1)
        if full_date >= "20250919":
            return True
    return False
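
For reference, a minimal illustration of the same date check using made-up torch_npu version strings (the strings below are hypothetical examples of the `dev<YYYYMMDD>` suffix, not real releases):

```python
import re

# Hypothetical version strings; only the dev<YYYYMMDD> suffix is inspected.
for ver in ["2.7.1.dev20250919", "2.7.1.dev20250801", "2.5.1"]:
    match = re.search(r'dev(\d{8})', ver)
    supported = bool(match) and match.group(1) >= "20250919"
    print(ver, "->", supported)
# 2.7.1.dev20250919 -> True
# 2.7.1.dev20250801 -> False
# 2.5.1 -> False
```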

View File

@@ -18,6 +18,8 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import logger
from vllm.platforms import current_platform

from vllm_ascend.attention.utils import version_check

from ..utils import weak_ref_tensors
@@ -212,18 +214,32 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
        ) = param
        # block_table = forward_context.attn_metadata[key].block_tables
        seq_lens = forward_context.attn_metadata[key].seq_lens
        torch_npu_check = version_check()

        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, handle)
            if torch_npu_check:
                torch_npu._npu_paged_attention(
                    query=query,
                    key_cache=key_cache,
                    value_cache=value_cache,
                    num_kv_heads=num_kv_heads,
                    num_heads=num_heads,
                    scale_value=scale,
                    block_table=block_table,
                    context_lens=seq_lens,
                    out=output,
                    workspace=graph_params.workspaces.get(runtime_shape))
            else:
                torch_npu._npu_paged_attention(query=query,
                                               key_cache=key_cache,
                                               value_cache=value_cache,
                                               num_kv_heads=num_kv_heads,
                                               num_heads=num_heads,
                                               scale_value=scale,
                                               block_table=block_table,
                                               context_lens=seq_lens,
                                               out=output)
            torch.npu.graph_task_update_end(update_stream)

        event.record(update_stream)
@@ -302,5 +318,11 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
    )


def update_graph_params_workspaces(num_tokens: int, workspace: int):
    global _graph_params
    if _graph_params is not None:
        _graph_params.workspaces[num_tokens] = workspace


def get_graph_params():
    return _graph_params