diff --git a/tests/e2e/multicard/test_full_graph_mode.py b/tests/e2e/multicard/test_full_graph_mode.py
new file mode 100644
index 000000000..3b9f29323
--- /dev/null
+++ b/tests/e2e/multicard/test_full_graph_mode.py
@@ -0,0 +1,72 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
+#
+"""Compare vLLM outputs with and without full-graph mode under greedy sampling.
+
+Run `pytest tests/e2e/multicard/test_full_graph_mode.py`.
+"""
+
+import os
+
+from vllm import SamplingParams
+
+from tests.e2e.conftest import VllmRunner
+from tests.e2e.model_utils import check_outputs_equal
+
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_FULLGRAPH():
+    if 'HCCL_OP_EXPANSION_MODE' in os.environ:
+        del os.environ['HCCL_OP_EXPANSION_MODE']
+    prompts = [
+        "Hello, my name is", "The president of the United States is",
+        "The capital of France is", "The future of AI is"
+    ]
+    model = "Qwen/Qwen3-30B-A3B"
+    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)
+    with VllmRunner(model,
+                    max_model_len=1024,
+                    tensor_parallel_size=2,
+                    enforce_eager=False,
+                    compilation_config={"cudagraph_mode":
+                                        "FULL_DECODE_ONLY"}) as runner:
+        vllm_fullgraph_outputs = runner.model.generate(prompts,
+                                                       sampling_params)
+
+    with VllmRunner(
+            model,
+            max_model_len=1024,
+            enforce_eager=True,
+    ) as runner:
+        vllm_eager_outputs = runner.model.generate(prompts, sampling_params)
+
+    vllm_fullgraph_outputs_list = []
+    for output in vllm_fullgraph_outputs:
+        vllm_fullgraph_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    vllm_eager_outputs_list = []
+    for output in vllm_eager_outputs:
+        vllm_eager_outputs_list.append(
+            (output.outputs[0].index, output.outputs[0].text))
+
+    check_outputs_equal(
+        outputs_0_lst=vllm_eager_outputs_list,
+        outputs_1_lst=vllm_fullgraph_outputs_list,
+        name_0="vllm_eager_outputs",
+        name_1="vllm_fullgraph_outputs",
+    )
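The full-graph path exercised above can be enabled the same way in offline inference. A minimal sketch, assuming a vLLM build whose `LLM` constructor forwards a `compilation_config` dict to the engine, mirroring what `VllmRunner` passes through in the test:

    # Hypothetical offline-inference equivalent of the fullgraph run in the test.
    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen3-30B-A3B",
              tensor_parallel_size=2,
              max_model_len=1024,
              enforce_eager=False,
              compilation_config={"cudagraph_mode": "FULL_DECODE_ONLY"})
    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(max_tokens=32, temperature=0.0))
    print(outputs[0].outputs[0].text)
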
diff --git a/tests/ut/attention/test_attention_v1.py b/tests/ut/attention/test_attention_v1.py
index d553637d9..e95db1a93 100644
--- a/tests/ut/attention/test_attention_v1.py
+++ b/tests/ut/attention/test_attention_v1.py
@@ -405,6 +405,113 @@ class TestAscendAttentionBackendImpl(TestBase):
         mock_paged_attention.assert_called_once()
         assert output.shape == (10, 8 * 64)
 
+    @patch('vllm_ascend.attention.attention_v1.get_forward_context')
+    @patch('vllm_ascend.attention.attention_v1.get_graph_params')
+    @patch('torch_npu._npu_reshape_and_cache')
+    @patch('torch_npu._npu_paged_attention')
+    @patch('torch.npu.graph_task_group_end')
+    @patch('torch.npu.graph_task_group_begin')
+    @patch('torch.npu.ExternalEvent')
+    @patch('torch_npu.npu.current_stream')
+    @patch('vllm_ascend.attention.attention_v1.weak_ref_tensors')
+    def test_paged_attention_with_existing_workspace(
+        self,
+        mock_weak_ref_tensors,
+        mock_current_stream,
+        mock_external_event_class,
+        mock_graph_begin,
+        mock_graph_end,
+        mock_paged_attention,
+        mock_npu_reshape_and_cache,
+        mock_get_graph_params,
+        mock_get_forward_context,
+    ):
+        graph_params = MagicMock()
+        attn_metadata = MagicMock()
+        num_tokens = 10
+
+        graph_params.workspaces = {num_tokens: 10}
+        graph_params.events = {num_tokens: []}
+        graph_params.attn_params = {num_tokens: []}
+        graph_params.handles = {num_tokens: []}
+
+        query = torch.randn(2, 5, 8)  # [batch_size, seq_len, hidden_size]
+        key_cache = MagicMock()
+        value_cache = MagicMock()
+        num_kv_heads = 4
+        num_heads = 8
+        scale = 0.1
+        output = torch.randn(2, 5, 8)
+
+        self_obj = MagicMock()
+        self_obj.key_cache = key_cache
+        self_obj.value_cache = value_cache
+        self_obj.num_kv_heads = num_kv_heads
+        self_obj.num_heads = num_heads
+        self_obj.scale = scale
+
+        mock_stream = MagicMock()
+        mock_current_stream.return_value = mock_stream
+        mock_event_instance = MagicMock()
+        mock_external_event_class.return_value = mock_event_instance
+
+        mock_handle = MagicMock()
+        mock_graph_end.return_value = mock_handle
+
+        workspace = graph_params.workspaces.get(num_tokens)
+        self.assertEqual(workspace, 10)
+
+        weak_ref_tensors = MagicMock(side_effect=lambda x: x)
+
+        # Simulate the bookkeeping done while capturing the graph.
+        stream = mock_current_stream()
+        event = mock_external_event_class()
+        event.wait(stream)
+        event.reset(stream)
+        graph_params.events[num_tokens].append(event)
+        graph_params.attn_params[num_tokens].append((
+            weak_ref_tensors(query),
+            weak_ref_tensors(self_obj.key_cache),
+            weak_ref_tensors(self_obj.value_cache),
+            self_obj.num_kv_heads,
+            self_obj.num_heads,
+            self_obj.scale,
+            weak_ref_tensors(attn_metadata.block_tables),
+            attn_metadata.seq_lens,
+            output,
+        ))
+
+        mock_event_instance.wait.assert_called_once_with(mock_stream)
+        mock_event_instance.reset.assert_called_once_with(mock_stream)
+        self.assertEqual(len(graph_params.events[num_tokens]), 1)
+        self.assertEqual(len(graph_params.attn_params[num_tokens]), 1)
+
+        query = torch.randn(10, 8 * 64)
+        key = torch.randn(10, 8 * 64)
+        value = torch.randn(10, 8 * 64)
+        kv_cache = torch.empty(2, 5, 128, 8, 64)
+        metadata = self.attn_metadata
+        metadata.attn_state = AscendAttentionState.DecodeOnly
+        metadata.seq_lens = torch.tensor([10])
+        metadata.block_tables = torch.zeros(1, 5, dtype=torch.long)
+        metadata.num_actual_tokens = 10
+        metadata.slot_mapping = torch.zeros(10, dtype=torch.long)
+        layer = self.layer_no_quant
+
+        mock_get_forward_context.return_value = MagicMock(capturing=True)
+        mock_get_graph_params.return_value = graph_params
+
+        output = self.impl.forward(layer,
+                                   query,
+                                   key,
+                                   value,
+                                   kv_cache,
+                                   metadata,
+                                   trace_flag=False)
+
+        mock_paged_attention.assert_called_once()
+        self.assertEqual(len(graph_params.handles[num_tokens]), 1)
+
     @patch('torch_npu._npu_reshape_and_cache')
     @patch('torch_npu.npu_fused_infer_attention_score')
     def test_forward_decode_only_swa(self, mock_fused_infer_attention_score,
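For reference, stacked `@patch` decorators are applied bottom-up, so the mock arguments arrive in the reverse of the decorator order; the parameter list above follows that rule, which is what lets `mock_get_forward_context` and `mock_get_graph_params` drive the capturing branch of `forward()`. A minimal, self-contained sketch of the behaviour using standard-library targets only:

    import os
    from unittest.mock import patch

    @patch("os.path.exists")   # outermost patch -> last mock argument
    @patch("os.listdir")       # innermost patch -> first mock argument
    def demo(mock_listdir, mock_exists):
        mock_listdir.return_value = ["a.txt"]
        mock_exists.return_value = True
        return os.listdir("."), os.path.exists("a.txt")

    assert demo() == (["a.txt"], True)
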
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index bf881b1bf..561ee5dd3 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -33,8 +33,10 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          maybe_save_kv_layer_to_connector,
+                                         version_check,
                                          wait_for_kv_layer_from_connector)
-from vllm_ascend.compilation.acl_graph import get_graph_params
+from vllm_ascend.compilation.acl_graph import (get_graph_params,
+                                               update_graph_params_workspaces)
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
@@ -289,6 +291,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
         self.key_cache = None
         self.value_cache = None
+        self.torch_npu_check = version_check()
 
     def _forward_prefill_no_cache(
         self,
@@ -396,13 +399,29 @@ class AscendAttentionBackendImpl(AttentionImpl):
            forward_context: ForwardContext = get_forward_context()
            num_tokens = query.shape[0]
            if forward_context.capturing:
+                if self.torch_npu_check:
+                    # Get workspace from cache or calculate it if not present.
+                    workspace = graph_params.workspaces.get(num_tokens)
+                    if workspace is None:
+                        workspace = torch_npu._npu_paged_attention_get_workspace(
+                            query=query,
+                            key_cache=self.key_cache,
+                            value_cache=self.value_cache,
+                            num_kv_heads=self.num_kv_heads,
+                            num_heads=self.num_heads,
+                            scale_value=self.scale,
+                            block_table=attn_metadata.block_tables,
+                            context_lens=attn_metadata.seq_lens,
+                            out=output)
+                        update_graph_params_workspaces(num_tokens, workspace)
+
+                # Handle graph capturing mode
                stream = torch_npu.npu.current_stream()
                event = torch.npu.ExternalEvent()
                event.wait(stream)
                event.reset(stream)
                graph_params.events[num_tokens].append(event)
-
                graph_params.attn_params[num_tokens].append((
                    weak_ref_tensors(query),
                    weak_ref_tensors(self.key_cache),
@@ -416,16 +435,30 @@ class AscendAttentionBackendImpl(AttentionImpl):
                ))
 
                torch.npu.graph_task_group_begin(stream)
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=self.key_cache,
-                    value_cache=self.value_cache,
-                    num_kv_heads=self.num_kv_heads,
-                    num_heads=self.num_heads,
-                    scale_value=self.scale,
-                    block_table=attn_metadata.block_tables,
-                    context_lens=attn_metadata.seq_lens,
-                    out=output)
+
+                if self.torch_npu_check:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output,
+                        workspace=workspace)
+                else:
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=self.key_cache,
+                        value_cache=self.value_cache,
+                        num_kv_heads=self.num_kv_heads,
+                        num_heads=self.num_heads,
+                        scale_value=self.scale,
+                        block_table=attn_metadata.block_tables,
+                        context_lens=attn_metadata.seq_lens,
+                        out=output)
                handle = torch.npu.graph_task_group_end(stream)
                graph_params.handles[num_tokens].append(handle)
            else:
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 271ff733e..cc4712251 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -1,7 +1,9 @@
+import functools
 from dataclasses import dataclass
 from typing import Any, List
 
 import torch
+import torch_npu
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                           has_kv_transfer_group,
                                           is_v1_kv_transfer_group)
@@ -139,3 +141,17 @@ def maybe_save_kv_layer_to_connector(
         return
     # TODO: assert ascendMetadata
     connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
+
+
+@functools.cache
+def version_check():
+    import re
+    torch_npu_version = torch_npu.version.__version__
+    date_pattern = r'dev(\d{8})'
+
+    match = re.search(date_pattern, torch_npu_version)
+    if match:
+        full_date = match.group(1)
+        if full_date >= "20250919":
+            return True
+    return False
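The `version_check()` gate compares the `devYYYYMMDD` date embedded in the `torch_npu` build string against 2025-09-19, the first nightly assumed to accept the `workspace` argument of `_npu_paged_attention`. A standalone illustration of the same date gate, using made-up version strings (the real value comes from `torch_npu.version.__version__`):

    import re

    def _supports_workspace_arg(torch_npu_version: str) -> bool:
        # Same regex and date comparison as version_check(), minus the torch_npu import.
        match = re.search(r'dev(\d{8})', torch_npu_version)
        return bool(match) and match.group(1) >= "20250919"

    assert _supports_workspace_arg("2.7.1.dev20250920") is True
    assert _supports_workspace_arg("2.7.1.dev20250918") is False
    assert _supports_workspace_arg("2.7.1") is False  # no dev date -> old call signature
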
diff --git a/vllm_ascend/compilation/acl_graph.py b/vllm_ascend/compilation/acl_graph.py
index e5f5ae71d..0065dd4af 100644
--- a/vllm_ascend/compilation/acl_graph.py
+++ b/vllm_ascend/compilation/acl_graph.py
@@ -18,6 +18,8 @@
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import logger
 from vllm.platforms import current_platform
 
+from vllm_ascend.attention.utils import version_check
+
 from ..utils import weak_ref_tensors
 
@@ -212,18 +214,32 @@ def update_attn_params(update_stream, forward_context, runtime_shape):
        ) = param
        # block_table = forward_context.attn_metadata[key].block_tables
        seq_lens = forward_context.attn_metadata[key].seq_lens
+        torch_npu_check = version_check()
 
        with torch.npu.stream(update_stream):
            torch.npu.graph_task_update_begin(update_stream, handle)
-            torch_npu._npu_paged_attention(query=query,
-                                           key_cache=key_cache,
-                                           value_cache=value_cache,
-                                           num_kv_heads=num_kv_heads,
-                                           num_heads=num_heads,
-                                           scale_value=scale,
-                                           block_table=block_table,
-                                           context_lens=seq_lens,
-                                           out=output)
+            if torch_npu_check:
+                torch_npu._npu_paged_attention(
+                    query=query,
+                    key_cache=key_cache,
+                    value_cache=value_cache,
+                    num_kv_heads=num_kv_heads,
+                    num_heads=num_heads,
+                    scale_value=scale,
+                    block_table=block_table,
+                    context_lens=seq_lens,
+                    out=output,
+                    workspace=graph_params.workspaces.get(runtime_shape))
+            else:
+                torch_npu._npu_paged_attention(query=query,
+                                               key_cache=key_cache,
+                                               value_cache=value_cache,
+                                               num_kv_heads=num_kv_heads,
+                                               num_heads=num_heads,
+                                               scale_value=scale,
+                                               block_table=block_table,
+                                               context_lens=seq_lens,
+                                               out=output)
            torch.npu.graph_task_update_end(update_stream)
 
        event.record(update_stream)
@@ -302,5 +318,11 @@ def set_graph_params(aclgraph_capture_sizes: set[int]):
     )
 
 
+def update_graph_params_workspaces(num_tokens: int, workspace: int):
+    global _graph_params
+    if _graph_params is not None:
+        _graph_params.workspaces[num_tokens] = workspace
+
+
 def get_graph_params():
     return _graph_params
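Taken together, the changes implement a per-shape workspace cache: the first capture at a given token count computes the workspace, `update_graph_params_workspaces` stores it under that shape, and later captures and replays at the same shape reuse it. A simplified sketch of the pattern, with `compute_workspace` standing in for `torch_npu._npu_paged_attention_get_workspace`:

    from typing import Any, Callable, Dict

    _workspaces: Dict[int, Any] = {}

    def get_or_create_workspace(num_tokens: int,
                                compute_workspace: Callable[[int], Any]) -> Any:
        workspace = _workspaces.get(num_tokens)
        if workspace is None:
            workspace = compute_workspace(num_tokens)
            _workspaces[num_tokens] = workspace  # mirrors update_graph_params_workspaces()
        return workspace

    # The second lookup for the same shape hits the cache instead of recomputing.
    ws_first = get_or_create_workspace(16, lambda n: object())
    ws_second = get_or_create_workspace(16, lambda n: object())
    assert ws_first is ws_second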