mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
fix pagedattention to support fullgraph. (#3436)
### What this PR does / why we need it? Calculate in advance the workspace memory size needed for the PagedAttention operator to avoid deadlocks during resource cleanup. This PR requires torch_npu version 0920 or newer. ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com> Co-authored-by: wangxiaoxin-sherie <wangxiaoxin7@huawei.com>
This commit is contained in:
72
tests/e2e/multicard/test_full_graph_mode.py
Normal file
72
tests/e2e/multicard/test_full_graph_mode.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
|
||||||
|
#
|
||||||
|
"""Compare the short outputs of HF and vLLM when using greedy sampling.
|
||||||
|
|
||||||
|
Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from vllm import SamplingParams
|
||||||
|
|
||||||
|
from tests.e2e.conftest import VllmRunner
|
||||||
|
from tests.e2e.model_utils import check_outputs_equal
|
||||||
|
|
||||||
|
|
||||||
|
def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
|
||||||
|
if 'HCCL_OP_EXPANSION_MODE' in os.environ:
|
||||||
|
del os.environ['HCCL_OP_EXPANSION_MODE']
|
||||||
|
prompts = [
|
||||||
|
"Hello, my name is", "The president of the United States is",
|
||||||
|
"The capital of France is", "The future of AI is"
|
||||||
|
]
|
||||||
|
model = "Qwen/Qwen3-30B-A3B"
|
||||||
|
sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
|
||||||
|
with VllmRunner(model,
|
||||||
|
max_model_len=1024,
|
||||||
|
tensor_parallel_size=2,
|
||||||
|
enforce_eager=False,
|
||||||
|
compilation_config={"cudagraph_mode":
|
||||||
|
"FULL_DECODE_ONLY"}) as runner:
|
||||||
|
vllm_fullgraph_outputs = runner.model.generate(prompts,
|
||||||
|
sampling_params)
|
||||||
|
|
||||||
|
with VllmRunner(
|
||||||
|
model,
|
||||||
|
max_model_len=1024,
|
||||||
|
enforce_eager=True,
|
||||||
|
) as runner:
|
||||||
|
vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
vllm_fullgraph_outputs_list = []
|
||||||
|
for output in vllm_fullgraph_outputs:
|
||||||
|
vllm_fullgraph_outputs_list.append(
|
||||||
|
(output.outputs[0].index, output.outputs[0].text))
|
||||||
|
|
||||||
|
vllm_eager_outputs_list = []
|
||||||
|
for output in vllm_eager_outputs:
|
||||||
|
vllm_eager_outputs_list.append(
|
||||||
|
(output.outputs[0].index, output.outputs[0].text))
|
||||||
|
|
||||||
|
check_outputs_equal(
|
||||||
|
outputs_0_lst=vllm_eager_outputs_list,
|
||||||
|
outputs_1_lst=vllm_fullgraph_outputs_list,
|
||||||
|
name_0="vllm_eager_outputs",
|
||||||
|
name_1="vllm_fullgraph_outputs",
|
||||||
|
)
|
@ -405,6 +405,113 @@ class TestAscendAttentionBackendImpl(TestBase):
|
|||||||
mock_paged_attention.assert_called_once()
|
mock_paged_attention.assert_called_once()
|
||||||
assert output.shape == (10, 8 * 64)
|
assert output.shape == (10, 8 * 64)
|
||||||
|
|
||||||
|
@patch('vllm_ascend.attention.attention_v1.get_forward_context')
|
||||||
|
@patch('vllm_ascend.attention.attention_v1.get_graph_params')
|
||||||
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
|
@patch('torch_npu._npu_paged_attention')
|
||||||
|
@patch('torch.npu.graph_task_group_end')
|
||||||
|
@patch('torch.npu.graph_task_group_begin')
|
||||||
|
@patch('torch.npu.ExternalEvent')
|
||||||
|
@patch('torch_npu.npu.current_stream')
|
||||||
|
@patch('vllm_ascend.attention.attention_v1.weak_ref_tensors')
|
||||||
|
def test_paged_attention_with_existing_workspace(
|
||||||
|
self,
|
||||||
|
mock_get_forward_context,
|
||||||
|
mock_get_graph_params,
|
||||||
|
mock_npu_reshape_and_cache,
|
||||||
|
mock_paged_attention,
|
||||||
|
mock_graph_begin,
|
||||||
|
mock_graph_end,
|
||||||
|
mock_external_event_class,
|
||||||
|
mock_current_stream,
|
||||||
|
mock_weak_ref_tensors,
|
||||||
|
):
|
||||||
|
graph_params = MagicMock()
|
||||||
|
attn_metadata = MagicMock()
|
||||||
|
num_tokens = 10
|
||||||
|
|
||||||
|
graph_params.workspaces = {num_tokens: 10}
|
||||||
|
graph_params.events = {num_tokens: []}
|
||||||
|
graph_params.attn_params = {num_tokens: []}
|
||||||
|
graph_params.handles = {num_tokens: []}
|
||||||
|
|
||||||
|
query = torch.randn(2, 5, 8) # [batch_size, seq_len, hidden_size]
|
||||||
|
key_cache = MagicMock()
|
||||||
|
value_cache = MagicMock()
|
||||||
|
num_kv_heads = 4
|
||||||
|
num_heads = 8
|
||||||
|
scale = 0.1
|
||||||
|
output = torch.randn(2, 5, 8)
|
||||||
|
|
||||||
|
self_obj = MagicMock()
|
||||||
|
self_obj.key_cache = key_cache
|
||||||
|
self_obj.value_cache = value_cache
|
||||||
|
self_obj.num_kv_heads = num_kv_heads
|
||||||
|
self_obj.num_heads = num_heads
|
||||||
|
self_obj.scale = scale
|
||||||
|
|
||||||
|
mock_stream = MagicMock()
|
||||||
|
mock_current_stream.return_value = mock_stream
|
||||||
|
mock_event_instance = MagicMock()
|
||||||
|
mock_external_event_class.return_value = mock_event_instance
|
||||||
|
|
||||||
|
mock_handle = MagicMock()
|
||||||
|
mock_graph_end.return_value = mock_handle
|
||||||
|
|
||||||
|
workspace = graph_params.workspaces.get(num_tokens)
|
||||||
|
self.assertEqual(workspace, 10)
|
||||||
|
|
||||||
|
weak_ref_tensors = MagicMock(side_effect=lambda x: x)
|
||||||
|
|
||||||
|
# 2. Handle graph capturing mode
|
||||||
|
stream = mock_current_stream()
|
||||||
|
event = mock_external_event_class()
|
||||||
|
event.wait(stream)
|
||||||
|
event.reset(stream)
|
||||||
|
graph_params.events[num_tokens].append(event)
|
||||||
|
graph_params.attn_params[num_tokens].append((
|
||||||
|
weak_ref_tensors(query),
|
||||||
|
weak_ref_tensors(self_obj.key_cache),
|
||||||
|
weak_ref_tensors(self_obj.value_cache),
|
||||||
|
self_obj.num_kv_heads,
|
||||||
|
self_obj.num_heads,
|
||||||
|
self_obj.scale,
|
||||||
|
weak_ref_tensors(attn_metadata.block_tables),
|
||||||
|
attn_metadata.seq_lens,
|
||||||
|
output,
|
||||||
|
))
|
||||||
|
|
||||||
|
mock_event_instance.wait.assert_called_once_with(mock_stream)
|
||||||
|
mock_event_instance.reset.assert_called_once_with(mock_stream)
|
||||||
|
self.assertEqual(len(graph_params.events[num_tokens]), 1)
|
||||||
|
self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
|
||||||
|
|
||||||
|
query = torch.randn(10, 8 * 64)
|
||||||
|
key = torch.randn(10, 8 * 64)
|
||||||
|
value = torch.randn(10, 8 * 64)
|
||||||
|
kv_cache = torch.empty(2, 5, 128, 8, 64)
|
||||||
|
metadata = self.attn_metadata
|
||||||
|
metadata.attn_state = AscendAttentionState.DecodeOnly
|
||||||
|
metadata.seq_lens = torch.tensor([10])
|
||||||
|
metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
|
||||||
|
metadata.num_actual_tokens = 10
|
||||||
|
metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
|
||||||
|
layer = self.layer_no_quant
|
||||||
|
|
||||||
|
mock_get_forward_context.return_value = MagicMock(capturing=True)
|
||||||
|
mock_get_graph_params.return_value = graph_params
|
||||||
|
|
||||||
|
output = self.impl.forward(layer,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
kv_cache,
|
||||||
|
metadata,
|
||||||
|
trace_flag=False)
|
||||||
|
|
||||||
|
mock_paged_attention.assert_called_once()
|
||||||
|
self.assertEqual(len(graph_params.handles[num_tokens]), 0)
|
||||||
|
|
||||||
@patch('torch_npu._npu_reshape_and_cache')
|
@patch('torch_npu._npu_reshape_and_cache')
|
||||||
@patch('torch_npu.npu_fused_infer_attention_score')
|
@patch('torch_npu.npu_fused_infer_attention_score')
|
||||||
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
|
def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
|
||||||
|
@ -33,8 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
|||||||
|
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
maybe_save_kv_layer_to_connector,
|
maybe_save_kv_layer_to_connector,
|
||||||
|
version_check,
|
||||||
wait_for_kv_layer_from_connector)
|
wait_for_kv_layer_from_connector)
|
||||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
from vllm_ascend.compilation.acl_graph import (get_graph_params,
|
||||||
|
update_graph_params_workspaces)
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||||
nd_to_nz_2d, nd_to_nz_spec)
|
nd_to_nz_2d, nd_to_nz_spec)
|
||||||
@ -289,6 +291,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
self.key_cache = None
|
self.key_cache = None
|
||||||
self.value_cache = None
|
self.value_cache = None
|
||||||
|
self.torch_npu_check = version_check()
|
||||||
|
|
||||||
def _forward_prefill_no_cache(
|
def _forward_prefill_no_cache(
|
||||||
self,
|
self,
|
||||||
@ -396,13 +399,29 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
forward_context: ForwardContext = get_forward_context()
|
forward_context: ForwardContext = get_forward_context()
|
||||||
num_tokens = query.shape[0]
|
num_tokens = query.shape[0]
|
||||||
if forward_context.capturing:
|
if forward_context.capturing:
|
||||||
|
if self.torch_npu_check:
|
||||||
|
# Get workspace from cache or calculate it if not present.
|
||||||
|
workspace = graph_params.workspaces.get(num_tokens)
|
||||||
|
if workspace is None:
|
||||||
|
workspace = torch_npu._npu_paged_attention_get_workspace(
|
||||||
|
query=query,
|
||||||
|
key_cache=self.key_cache,
|
||||||
|
value_cache=self.value_cache,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
scale_value=self.scale,
|
||||||
|
block_table=attn_metadata.block_tables,
|
||||||
|
context_lens=attn_metadata.seq_lens,
|
||||||
|
out=output)
|
||||||
|
update_graph_params_workspaces(num_tokens, workspace)
|
||||||
|
|
||||||
|
# Handle graph capturing mode
|
||||||
stream = torch_npu.npu.current_stream()
|
stream = torch_npu.npu.current_stream()
|
||||||
|
|
||||||
event = torch.npu.ExternalEvent()
|
event = torch.npu.ExternalEvent()
|
||||||
event.wait(stream)
|
event.wait(stream)
|
||||||
event.reset(stream)
|
event.reset(stream)
|
||||||
graph_params.events[num_tokens].append(event)
|
graph_params.events[num_tokens].append(event)
|
||||||
|
|
||||||
graph_params.attn_params[num_tokens].append((
|
graph_params.attn_params[num_tokens].append((
|
||||||
weak_ref_tensors(query),
|
weak_ref_tensors(query),
|
||||||
weak_ref_tensors(self.key_cache),
|
weak_ref_tensors(self.key_cache),
|
||||||
@ -416,16 +435,30 @@ class AscendAttentionBackendImpl(AttentionImpl):
|
|||||||
))
|
))
|
||||||
|
|
||||||
torch.npu.graph_task_group_begin(stream)
|
torch.npu.graph_task_group_begin(stream)
|
||||||
torch_npu._npu_paged_attention(
|
|
||||||
query=query,
|
if self.torch_npu_check:
|
||||||
key_cache=self.key_cache,
|
torch_npu._npu_paged_attention(
|
||||||
value_cache=self.value_cache,
|
query=query,
|
||||||
num_kv_heads=self.num_kv_heads,
|
key_cache=self.key_cache,
|
||||||
num_heads=self.num_heads,
|
value_cache=self.value_cache,
|
||||||
scale_value=self.scale,
|
num_kv_heads=self.num_kv_heads,
|
||||||
block_table=attn_metadata.block_tables,
|
num_heads=self.num_heads,
|
||||||
context_lens=attn_metadata.seq_lens,
|
scale_value=self.scale,
|
||||||
out=output)
|
block_table=attn_metadata.block_tables,
|
||||||
|
context_lens=attn_metadata.seq_lens,
|
||||||
|
out=output,
|
||||||
|
workspace=workspace)
|
||||||
|
else:
|
||||||
|
torch_npu._npu_paged_attention(
|
||||||
|
query=query,
|
||||||
|
key_cache=self.key_cache,
|
||||||
|
value_cache=self.value_cache,
|
||||||
|
num_kv_heads=self.num_kv_heads,
|
||||||
|
num_heads=self.num_heads,
|
||||||
|
scale_value=self.scale,
|
||||||
|
block_table=attn_metadata.block_tables,
|
||||||
|
context_lens=attn_metadata.seq_lens,
|
||||||
|
out=output)
|
||||||
handle = torch.npu.graph_task_group_end(stream)
|
handle = torch.npu.graph_task_group_end(stream)
|
||||||
graph_params.handles[num_tokens].append(handle)
|
graph_params.handles[num_tokens].append(handle)
|
||||||
else:
|
else:
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
|
import functools
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, List
|
from typing import Any, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch_npu
|
||||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||||
has_kv_transfer_group,
|
has_kv_transfer_group,
|
||||||
is_v1_kv_transfer_group)
|
is_v1_kv_transfer_group)
|
||||||
@ -139,3 +141,17 @@ def maybe_save_kv_layer_to_connector(
|
|||||||
return
|
return
|
||||||
# TODO: assert ascendMetadata
|
# TODO: assert ascendMetadata
|
||||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
|
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
|
||||||
|
def version_check():
|
||||||
|
import re
|
||||||
|
torch_npu_version = torch_npu.version.__version__
|
||||||
|
date_pattern = r'dev(\d{8})'
|
||||||
|
|
||||||
|
match = re.search(date_pattern, torch_npu_version)
|
||||||
|
if match:
|
||||||
|
full_date = match.group(1)
|
||||||
|
if full_date >= "20250919":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
@ -18,6 +18,8 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
|
|||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from vllm_ascend.attention.utils import version_check
|
||||||
|
|
||||||
from ..utils import weak_ref_tensors
|
from ..utils import weak_ref_tensors
|
||||||
|
|
||||||
|
|
||||||
@ -212,18 +214,32 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
|
|||||||
) = param
|
) = param
|
||||||
# block_table = forward_context.attn_metadata[key].block_tables
|
# block_table = forward_context.attn_metadata[key].block_tables
|
||||||
seq_lens = forward_context.attn_metadata[key].seq_lens
|
seq_lens = forward_context.attn_metadata[key].seq_lens
|
||||||
|
torch_npu_check = version_check()
|
||||||
|
|
||||||
with torch.npu.stream(update_stream):
|
with torch.npu.stream(update_stream):
|
||||||
torch.npu.graph_task_update_begin(update_stream, handle)
|
torch.npu.graph_task_update_begin(update_stream, handle)
|
||||||
torch_npu._npu_paged_attention(query=query,
|
if torch_npu_check:
|
||||||
key_cache=key_cache,
|
torch_npu._npu_paged_attention(
|
||||||
value_cache=value_cache,
|
query=query,
|
||||||
num_kv_heads=num_kv_heads,
|
key_cache=key_cache,
|
||||||
num_heads=num_heads,
|
value_cache=value_cache,
|
||||||
scale_value=scale,
|
num_kv_heads=num_kv_heads,
|
||||||
block_table=block_table,
|
num_heads=num_heads,
|
||||||
context_lens=seq_lens,
|
scale_value=scale,
|
||||||
out=output)
|
block_table=block_table,
|
||||||
|
context_lens=seq_lens,
|
||||||
|
out=output,
|
||||||
|
workspace=graph_params.workspaces.get(runtime_shape))
|
||||||
|
else:
|
||||||
|
torch_npu._npu_paged_attention(query=query,
|
||||||
|
key_cache=key_cache,
|
||||||
|
value_cache=value_cache,
|
||||||
|
num_kv_heads=num_kv_heads,
|
||||||
|
num_heads=num_heads,
|
||||||
|
scale_value=scale,
|
||||||
|
block_table=block_table,
|
||||||
|
context_lens=seq_lens,
|
||||||
|
out=output)
|
||||||
torch.npu.graph_task_update_end(update_stream)
|
torch.npu.graph_task_update_end(update_stream)
|
||||||
|
|
||||||
event.record(update_stream)
|
event.record(update_stream)
|
||||||
@ -302,5 +318,11 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def update_graph_params_workspaces(num_tokens: int, workspace: int):
|
||||||
|
global _graph_params
|
||||||
|
if _graph_params is not None:
|
||||||
|
_graph_params.workspaces[num_tokens] = workspace
|
||||||
|
|
||||||
|
|
||||||
def get_graph_params():
|
def get_graph_params():
|
||||||
return _graph_params
|
return _graph_params
|
||||||
|
Reference in New Issue
Block a user