From 02c26dcfc7632e90b280a1d20481826b442b9c69 Mon Sep 17 00:00:00 2001
From: xuyexiong
Date: Tue, 14 Oct 2025 23:07:45 +0800
Subject: [PATCH] [Feat] Supports Aclgraph for bge-m3 (#3171)

### What this PR does / why we need it?
[Feat] Supports Aclgraph for bge-m3

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
```
pytest -s tests/e2e/singlecard/test_embedding.py
pytest -s tests/e2e/singlecard/test_embedding_aclgraph.py
```
To start an online server with batch size 10 and a sequence length of 8192 per request, we set --max-num-batched-tokens=8192*10 so that the encoder input is not chunked:
```
vllm serve /home/data/bge-m3 --max_model_len 1024 --served-model-name "bge-m3" --task embed --host 0.0.0.0 --port 9095 --max-num-batched-tokens 81920 --compilation-config '{"cudagraph_capture_sizes":[8192, 10240, 20480, 40960, 81920]}'
```
For batch size 10 with a sequence length of 8192 per request, QPS improves from 85 to 104, a 22% improvement, and much of the host-bound overhead is eliminated.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: xuyexiong
Co-authored-by: wangyongjun <1104133197@qq.com>
---
 .github/workflows/_e2e_test.yaml              |  2 +
 tests/e2e/singlecard/test_bge_model.py        | 49 +++++++++++
 .../e2e/singlecard/test_embedding_aclgraph.py | 55 ++++++++++++
 vllm_ascend/attention/attention_mask.py       |  9 ++
 vllm_ascend/attention/attention_v1.py         | 25 ++++--
 vllm_ascend/patch/__init__.py                 | 19 ++++
 .../patch/worker/patch_common/__init__.py     |  1 +
 .../patch_common/patch_attention_selector.py  |  1 +
 .../worker/patch_common/patch_roberta.py      | 88 +++++++++++++++++++
 vllm_ascend/platform.py                       |  3 +-
 vllm_ascend/worker/model_runner_v1.py         | 76 ++++++++++++----
 11 files changed, 307 insertions(+), 21 deletions(-)
 create mode 100644 tests/e2e/singlecard/test_bge_model.py
 create mode 100644 tests/e2e/singlecard/test_embedding_aclgraph.py
 create mode 100644 vllm_ascend/patch/worker/patch_common/patch_roberta.py

diff --git a/.github/workflows/_e2e_test.yaml b/.github/workflows/_e2e_test.yaml
index 44f0708ce..276144795 100644
--- a/.github/workflows/_e2e_test.yaml
+++ b/.github/workflows/_e2e_test.yaml
@@ -90,9 +90,11 @@ jobs:
           pytest -sv tests/e2e/singlecard/test_aclgraph.py
           pytest -sv tests/e2e/singlecard/test_ascend_scheduler.py
+          pytest -sv tests/e2e/singlecard/test_bge_model.py
           pytest -sv tests/e2e/singlecard/test_camem.py
           pytest -sv tests/e2e/singlecard/test_chunked.py
           pytest -sv tests/e2e/singlecard/test_embedding.py
+          pytest -sv tests/e2e/singlecard/test_embedding_aclgraph.py
           pytest -sv tests/e2e/singlecard/test_guided_decoding.py
           pytest -sv tests/e2e/singlecard/test_ilama_lora.py
           pytest -sv tests/e2e/singlecard/test_profile_execute_duration.py
diff --git a/tests/e2e/singlecard/test_bge_model.py b/tests/e2e/singlecard/test_bge_model.py
new file mode 100644
index 000000000..968bf1c7d
--- /dev/null
+++ b/tests/e2e/singlecard/test_bge_model.py
@@ -0,0 +1,49 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# +from modelscope import snapshot_download # type: ignore[import-untyped] + +from tests.e2e.conftest import HfRunner, VllmRunner +from tests.e2e.utils import check_embeddings_close + + +def test_bge_model_correctness(): + queries = ['What is the capital of China?', 'Explain gravity'] + + model_name = snapshot_download("BAAI/bge-m3") + with VllmRunner( + model_name, + task="embed", + enforce_eager=True, + ) as vllm_runner: + vllm_outputs = vllm_runner.encode(queries) + + with HfRunner( + model_name, + dtype="float32", + is_sentence_transformer=True, + ) as hf_runner: + hf_outputs = hf_runner.encode(queries) + + check_embeddings_close( + embeddings_0_lst=hf_outputs, + embeddings_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/tests/e2e/singlecard/test_embedding_aclgraph.py b/tests/e2e/singlecard/test_embedding_aclgraph.py new file mode 100644 index 000000000..e0851b064 --- /dev/null +++ b/tests/e2e/singlecard/test_embedding_aclgraph.py @@ -0,0 +1,55 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py +# +import os + +import pytest + +from tests.e2e.conftest import VllmRunner +from tests.e2e.utils import check_embeddings_close + +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +MODELS = ["BAAI/bge-m3"] + + +@pytest.mark.parametrize("model_name", MODELS) +def test_aclgrpah_embed_models_correctness(model_name): + queries = ['What is the capital of China?', 'Explain gravity'] + + with VllmRunner( + model_name, + task="embed", + enforce_eager=False, + ) as vllm_aclgraph_runner: + vllm_aclgraph_outputs = vllm_aclgraph_runner.encode(queries) + + with VllmRunner( + model_name, + task="embed", + enforce_eager=True, + ) as vllm_runner: + vllm_outputs = vllm_runner.encode(queries) + + check_embeddings_close( + embeddings_0_lst=vllm_outputs, + embeddings_1_lst=vllm_aclgraph_outputs, + name_0="hf", + name_1="vllm", + tol=1e-2, + ) diff --git a/vllm_ascend/attention/attention_mask.py b/vllm_ascend/attention/attention_mask.py index 225d4b903..079efff4d 100644 --- a/vllm_ascend/attention/attention_mask.py +++ b/vllm_ascend/attention/attention_mask.py @@ -50,6 +50,7 @@ class AttentionMaskBuilder: self._seq_len_cached = attn_mask.shape[0] self.attn_mask_cache = attn_mask self.device = device + self.pooling_mask = None if torch.version.cann.startswith("8.3"): assigned_mask_dim = 2048 self.chunked_prefill_attn_mask = torch.triu( @@ -75,6 +76,14 @@ class AttentionMaskBuilder: return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous( ).to(device, non_blocking=True) + def get_pooling_mask(self, device): + if self.pooling_mask is None: + # the compressed attention mask for npu_fusion_attention sparse mode 4 + self.pooling_mask = torch.triu(torch.ones( + 2048, 2048), diagonal=1).to(torch.bool).to(device, + non_blocking=True) + return self.pooling_mask + def get_splitfuse_attn_mask( self, seq_lens: torch.Tensor = None, diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 561ee5dd3..53c93a61d 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -606,9 +606,8 @@ class AscendAttentionBackendImpl(AttentionImpl): num_actual_tokens = attn_metadata.num_actual_tokens assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 attn_type = self.attn_type - if attn_type != AttentionType.DECODER: - raise NotImplementedError("Encoder self-attention and " - "encoder/decoder cross-attention " + if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY: + raise NotImplementedError("Encoder/decoder cross-attention " "are not implemented for " "PallasAttentionBackendImpl") # View q k v to BSH. @@ -628,9 +627,25 @@ class AscendAttentionBackendImpl(AttentionImpl): key_cache=self.key_cache, value_cache=self.value_cache, slot_indices=slots) - + if attn_type == AttentionType.ENCODER_ONLY: + cum_seq_len = attn_metadata.query_start_loc[1:].tolist() + attn_out = torch_npu.npu_fusion_attention( + query, + key, + value, + head_num=self.num_heads, + input_layout="TND", + scale=self.scale, + sparse_mode=4, + atten_mask=attn_metadata.attn_mask, + pre_tockens=attn_metadata.max_query_len, + next_tockens=attn_metadata.max_query_len, + actual_seq_qlen=cum_seq_len, + actual_seq_kvlen=cum_seq_len, + ) + output = attn_out[0] # V0-Style scheduler situation. 
- if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: + elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache: output = self._forward_prefill_no_cache( query, key, value, attn_metadata, output, num_tokens) elif attn_metadata.attn_state == \ diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 6cb9004d8..f76d8810c 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -132,3 +132,22 @@ # - this is a bug by Ascend only. It can' be fixed in vLLM. # Future Plan: # Fix this bug in torch-npu, bump torch-npu version and remove this patch. +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.model_executor.models.roberta.RobertaEmbedding.forward` +# Why: +# shift operation in `_encode_token_type_ids` and `_decode_token_type_ids` cannot run in ascend aclgraph mode +# How: +# Replace shift operation with multiplication and division. +# Related PR (if no, explain why): +# No, this need CANN add an aclnn shift operation +# Future Plan: +# Revert this when CANN support shift aclnn operation +# 2. `vllm.model_executor.models.roberta.RobertaForSequenceClassification.forward ` +# Why: +# shift operation in `_encode_token_type_ids` and `_decode_token_type_ids` cannot run in ascend aclgraph mode +# How: +# Replace shift operation with multiplication and division. +# Related PR (if no, explain why): +# No, this need CANN add an aclnn shift operation +# Future Plan: +# Revert this when CANN support shift aclnn operation diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index c9bea7158..99ec34561 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -25,6 +25,7 @@ import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa import vllm_ascend.patch.worker.patch_common.patch_logits # noqa +import vllm_ascend.patch.worker.patch_common.patch_roberta # noqa import vllm_ascend.patch.worker.patch_common.patch_weight_loader # noqa import vllm_ascend.patch.worker.patch_common.patch_multimodal_merge # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_attention_selector.py b/vllm_ascend/patch/worker/patch_common/patch_attention_selector.py index b456e6630..3bea9d461 100644 --- a/vllm_ascend/patch/worker/patch_common/patch_attention_selector.py +++ b/vllm_ascend/patch/worker/patch_common/patch_attention_selector.py @@ -64,6 +64,7 @@ def _cached_get_attn_backend( use_mla: bool = False, use_sfa: bool = False, has_sink: bool = False, + use_sparse: bool = False, ) -> type[AttentionBackend]: # Check whether a particular choice of backend was # previously forced. diff --git a/vllm_ascend/patch/worker/patch_common/patch_roberta.py b/vllm_ascend/patch/worker/patch_common/patch_roberta.py new file mode 100644 index 000000000..9c9f5e89d --- /dev/null +++ b/vllm_ascend/patch/worker/patch_common/patch_roberta.py @@ -0,0 +1,88 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from typing import Optional + +import torch +from vllm.model_executor.models.roberta import ( + RobertaEmbedding, RobertaForSequenceClassification, + replace_roberta_positions) +from vllm.sequence import IntermediateTensors + +# aclgraph does not support shift operator for now +# TODO: revert me when aclgraph supports shift operator +TOKEN_TYPE_SHIFT = 30 +TOKEN_TYPE_MULTIPLIER = 1 << 30 +TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1 + + +def _encode_token_type_ids(input_ids: torch.Tensor, + token_type_ids: torch.Tensor) -> None: + # input_ids can be padded to the right + input_ids[:token_type_ids.shape[0]].bitwise_or_(token_type_ids * + TOKEN_TYPE_MULTIPLIER) + + +def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor: + + token_type_ids = input_ids // TOKEN_TYPE_MULTIPLIER + + input_ids.bitwise_and_(TOKEN_MASK) + + return token_type_ids + + +def roberta_for_sequence_classification_forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, +) -> torch.Tensor: + replace_roberta_positions(input_ids=input_ids, + position_ids=positions, + padding_idx=self.padding_idx) + if token_type_ids is not None: + assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT) + assert input_ids is not None + _encode_token_type_ids(input_ids, token_type_ids) + return self.roberta(input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors) + + +def roberta_embedding_forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, +) -> torch.Tensor: + + token_type_ids = _decode_token_type_ids(input_ids) + + inputs_embeds = self.word_embeddings(input_ids) + position_embeddings = self.position_embeddings(position_ids) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = inputs_embeds + token_type_embeddings + position_embeddings + embeddings = self.LayerNorm(embeddings) + return embeddings + + +RobertaEmbedding.forward = roberta_embedding_forward +RobertaForSequenceClassification.forward = roberta_for_sequence_classification_forward diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index a90a73e1b..b00e8d60b 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -134,7 +134,8 @@ class NPUPlatform(Platform): structured_outputs_config = vllm_config.structured_outputs_config if (model_config is not None and not model_config.use_mla - and not scheduler_config.async_scheduling): + and not scheduler_config.async_scheduling + and model_config.runner_type != "pooling"): logger.info( "Non-MLA LLMs forcibly disable the chunked prefill feature," "as the performance of operators supporting this feature " diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index b432db063..3253492b2 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -77,10 +77,11 @@ from vllm.v1.attention.backends.utils import ( from 
vllm.v1.cudagraph_dispatcher import CudagraphDispatcher # yapf conflicts with isort for this block # yapf: disable -from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, - KVCacheConfig, KVCacheGroupSpec, - KVCacheSpec, MambaSpec, - MLAAttentionSpec, +from vllm.v1.kv_cache_interface import (AttentionSpec, + EncoderOnlyAttentionSpec, + FullAttentionSpec, KVCacheConfig, + KVCacheGroupSpec, KVCacheSpec, + MambaSpec, MLAAttentionSpec, UniformTypeKVCacheSpecs) # yapf: enable from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, @@ -867,8 +868,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): def _make_attention_mask(self, seq_lens, position, attn_state) -> torch.Tensor: + # Pooling situation. + if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS": + return self.attn_mask_builder.get_pooling_mask(self.device) # Chunk Prefill situation. - if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa: + elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa: if torch.version.cann.startswith("8.3"): return self.attn_mask_builder.get_splitfuse_attn_mask() else: @@ -1426,14 +1430,29 @@ class NPUModelRunner(LoRAModelRunnerMixin): # in the same group share the same metadata. for kv_cache_group_id, kv_cache_group_spec in enumerate( self.kv_cache_config.kv_cache_groups): - blk_table = self.input_batch.block_table[kv_cache_group_id] - blk_table_tensor = blk_table.get_device_tensor() - slot_mapping = blk_table.slot_mapping_cpu[: - total_num_scheduled_tokens] - self.slot_mapping[:total_num_scheduled_tokens].copy_( - slot_mapping[:total_num_scheduled_tokens], - non_blocking=True, - ) + if isinstance(kv_cache_group_spec.kv_cache_spec, + EncoderOnlyAttentionSpec): + # Encoder-only layers do not have KV cache, so we need to + # create a dummy block table and slot mapping for them. + blk_table_tensor = torch.zeros( + (num_reqs, 1), + dtype=torch.int32, + device=self.device, + ) + slot_mapping = torch.zeros( + (total_num_scheduled_tokens, ), + dtype=torch.int64, + device=self.device, + ) + else: + blk_table = self.input_batch.block_table[kv_cache_group_id] + blk_table_tensor = blk_table.get_device_tensor() + slot_mapping = blk_table.slot_mapping_cpu[: + total_num_scheduled_tokens] + self.slot_mapping[:total_num_scheduled_tokens].copy_( + slot_mapping[:total_num_scheduled_tokens], + non_blocking=True, + ) # Make AscendCommonAttentionMetadata common_attn_metadata = AscendCommonAttentionMetadata( @@ -1469,7 +1488,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): common_prefix_len = 0 extra_attn_metadata_args = {} builder = attn_group.get_metadata_builder() - if isinstance(builder, GDNAttentionMetadataBuilder): + if isinstance(builder, GDNAttentionMetadataBuilder + ) or self.model_config.runner_type == "pooling": if use_spec_decode: extra_attn_metadata_args = dict( num_accepted_tokens=self.num_accepted_tokens. @@ -2625,7 +2645,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): """ kv_cache_config = deepcopy(kv_cache_config) self.kv_cache_config = kv_cache_config - self.initialize_attn_backend(kv_cache_config) self.use_hybrid_blocks = (len(self.attn_groups) > 1) # NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`. 
self.need_accepted_tokens = any([ @@ -2634,6 +2653,8 @@ class NPUModelRunner(LoRAModelRunnerMixin): ]) self.may_reinitialize_input_batch(kv_cache_config) + self.may_add_encoder_only_layers_to_kv_cache_config() + self.initialize_attn_backend(kv_cache_config) if self.ascend_config.is_deepseek_sfa: kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa( @@ -3089,6 +3110,31 @@ class NPUModelRunner(LoRAModelRunnerMixin): kernel_block_sizes=kernel_block_sizes, ) + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: + """ + Add encoder-only layers to the KV cache config. + """ + block_size = self.vllm_config.cache_config.block_size + encoder_only_attn_specs: dict[AttentionSpec, + list[str]] = defaultdict(list) + attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) + for layer_name, attn_module in attn_layers.items(): + if attn_module.attn_type == AttentionType.ENCODER_ONLY: + attn_spec: AttentionSpec = EncoderOnlyAttentionSpec( + block_size=block_size, + num_kv_heads=attn_module.num_kv_heads, + head_size=attn_module.head_size, + dtype=self.kv_cache_dtype) + encoder_only_attn_specs[attn_spec].append(layer_name) + self.runner_only_attn_layers.add(layer_name) + if len(encoder_only_attn_specs) > 0: + assert len( + encoder_only_attn_specs + ) == 1, "Only support one encoder-only attention spec now" + spec, layer_names = encoder_only_attn_specs.popitem() + self.kv_cache_config.kv_cache_groups.append( + KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec)) + def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: """ Initialize the attention backends and attention metadata builders.
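
The core trick behind the `patch_roberta.py` change above is packing `token_type_ids` into the high bits of `input_ids` with a multiplication instead of a left shift, because the shift operator currently cannot be captured in aclgraph mode. The standalone sketch below is not part of the patch; it is a plain-PyTorch illustration, reusing the same `TOKEN_TYPE_SHIFT = 30` constants, that checks the multiply/divide packing is bit-for-bit equivalent to the shift-based encoding used upstream:
```
import torch

TOKEN_TYPE_SHIFT = 30
TOKEN_TYPE_MULTIPLIER = 1 << 30
TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1


def encode_mul(input_ids: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
    # multiplication-based packing (aclgraph-friendly, as in patch_roberta.py)
    return input_ids | (token_type_ids * TOKEN_TYPE_MULTIPLIER)


def encode_shift(input_ids: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
    # shift-based packing used by upstream vLLM
    return input_ids | (token_type_ids << TOKEN_TYPE_SHIFT)


def decode(packed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # recover input_ids via masking and token_type_ids via integer division
    return packed & TOKEN_MASK, packed // TOKEN_TYPE_MULTIPLIER


input_ids = torch.randint(0, 250_000, (16,), dtype=torch.int64)  # vocab ids < 2**30
token_type_ids = torch.randint(0, 2, (16,), dtype=torch.int64)

packed = encode_mul(input_ids, token_type_ids)
assert torch.equal(packed, encode_shift(input_ids, token_type_ids))
ids, types = decode(packed)
assert torch.equal(ids, input_ids) and torch.equal(types, token_type_ids)
print("multiply/divide packing matches the shift-based encoding")
```
This only holds while the vocabulary size stays below 2**30, which is exactly the assertion kept in `roberta_for_sequence_classification_forward`.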
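For completeness, here is one way to exercise the online-serving setup from the PR description once the server above is running. This client snippet is only an illustration, not part of the patch; it assumes vLLM's OpenAI-compatible `/v1/embeddings` endpoint and reuses the `--served-model-name bge-m3` and `--port 9095` values from the example command:
```
import requests

resp = requests.post(
    "http://localhost:9095/v1/embeddings",
    json={
        "model": "bge-m3",
        "input": ["What is the capital of China?", "Explain gravity"],
    },
    timeout=60,
)
resp.raise_for_status()
# OpenAI-compatible response: one embedding vector per input string
embeddings = [item["embedding"] for item in resp.json()["data"]]
print(len(embeddings), len(embeddings[0]))
```
Batching many such requests with long inputs (e.g. 10 concurrent requests of 8192 tokens) is the scenario where the aclgraph capture sizes configured above are expected to pay off.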