Mirror of https://github.com/vllm-project/vllm-ascend.git, synced 2025-10-20 13:43:53 +08:00
[Feat] Supports Aclgraph for bge-m3 (#3171)
### What this PR does / why we need it?
Adds Aclgraph support for the bge-m3 embedding model.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
```
pytest -s tests/e2e/singlecard/test_embedding.py
pytest -s tests/e2e/singlecard/test_embedding_aclgraph.py
```
To start an online server with batch size 10 and a sequence length of 8192 per request, set `--max-num-batched-tokens=8192*10` (81920) so the encoder input is not chunked:
```
vllm serve /home/data/bge-m3 --max_model_len 1024 --served-model-name "bge-m3" --task embed --host 0.0.0.0 --port 9095 --max-num-batched-tokens 81920 --compilation-config '{"cudagraph_capture_sizes":[8192, 10240, 20480, 40960, 81920]}'
```
With batch size 10 and a per-request sequence length of 8192, QPS improves from 85 to 104, a 22% improvement; most of the host-bound overhead is removed.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
Co-authored-by: wangyongjun <1104133197@qq.com>
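For reference, a minimal client sketch for exercising the embedding server started above. It assumes vLLM's OpenAI-compatible `/v1/embeddings` route for `--task embed` and reuses the host, port, and served model name from the command; the script is illustrative and not part of this PR.

```python
# Illustrative client for the server started with `vllm serve` above.
# Assumes the OpenAI-compatible /v1/embeddings route exposed for --task embed;
# adjust host/port/model name to your deployment.
import requests

resp = requests.post(
    "http://0.0.0.0:9095/v1/embeddings",
    json={
        "model": "bge-m3",  # --served-model-name from the command above
        "input": ["What is the capital of China?", "Explain gravity"],
    },
    timeout=60,
)
resp.raise_for_status()
embeddings = [item["embedding"] for item in resp.json()["data"]]
print(f"got {len(embeddings)} embeddings of dim {len(embeddings[0])}")
```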
.github/workflows/_e2e_test.yaml (2 changes, vendored)

```diff
@@ -90,9 +90,11 @@ jobs:
           pytest -sv tests/e2e/singlecard/test_aclgraph.py
           pytest -sv tests/e2e/singlecard/test_ascend_scheduler.py
+          pytest -sv tests/e2e/singlecard/test_bge_model.py
           pytest -sv tests/e2e/singlecard/test_camem.py
           pytest -sv tests/e2e/singlecard/test_chunked.py
           pytest -sv tests/e2e/singlecard/test_embedding.py
+          pytest -sv tests/e2e/singlecard/test_embedding_aclgraph.py
           pytest -sv tests/e2e/singlecard/test_guided_decoding.py
           pytest -sv tests/e2e/singlecard/test_ilama_lora.py
           pytest -sv tests/e2e/singlecard/test_profile_execute_duration.py
```
tests/e2e/singlecard/test_bge_model.py (new file, 49 lines)

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
from modelscope import snapshot_download  # type: ignore[import-untyped]

from tests.e2e.conftest import HfRunner, VllmRunner
from tests.e2e.utils import check_embeddings_close


def test_bge_model_correctness():
    queries = ['What is the capital of China?', 'Explain gravity']

    model_name = snapshot_download("BAAI/bge-m3")
    with VllmRunner(
            model_name,
            task="embed",
            enforce_eager=True,
    ) as vllm_runner:
        vllm_outputs = vllm_runner.encode(queries)

    with HfRunner(
            model_name,
            dtype="float32",
            is_sentence_transformer=True,
    ) as hf_runner:
        hf_outputs = hf_runner.encode(queries)

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )
```
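For orientation, a sketch of the kind of comparison `check_embeddings_close` performs; the real helper lives in `tests/e2e/utils.py`, so this cosine-similarity version is an assumption about its behavior, not its actual implementation.

```python
import torch
import torch.nn.functional as F


def cosine_close(embeddings_0, embeddings_1, tol=1e-2):
    """Assert each pair of embeddings has cosine similarity within tol of 1.0."""
    for i, (e0, e1) in enumerate(zip(embeddings_0, embeddings_1)):
        a = torch.as_tensor(e0, dtype=torch.float32)
        b = torch.as_tensor(e1, dtype=torch.float32)
        sim = F.cosine_similarity(a, b, dim=-1).item()
        assert 1.0 - sim <= tol, f"pair {i}: cosine similarity {sim:.4f} exceeds tolerance"
```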
tests/e2e/singlecard/test_embedding_aclgraph.py (new file, 55 lines)

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
import os

import pytest

from tests.e2e.conftest import VllmRunner
from tests.e2e.utils import check_embeddings_close

os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

MODELS = ["BAAI/bge-m3"]


@pytest.mark.parametrize("model_name", MODELS)
def test_aclgrpah_embed_models_correctness(model_name):
    queries = ['What is the capital of China?', 'Explain gravity']

    with VllmRunner(
            model_name,
            task="embed",
            enforce_eager=False,
    ) as vllm_aclgraph_runner:
        vllm_aclgraph_outputs = vllm_aclgraph_runner.encode(queries)

    with VllmRunner(
            model_name,
            task="embed",
            enforce_eager=True,
    ) as vllm_runner:
        vllm_outputs = vllm_runner.encode(queries)

    check_embeddings_close(
        embeddings_0_lst=vllm_outputs,
        embeddings_1_lst=vllm_aclgraph_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )
```
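The eager run serves as the reference; the aclgraph run exercises the default capture sizes. If `VllmRunner` forwards extra keyword arguments to vLLM's `LLM` constructor (an assumption about `tests/e2e/conftest.py`, not verified here), the capture sizes from the serve command in the description could also be pinned explicitly:

```python
# Hypothetical variant, assuming VllmRunner passes extra kwargs through
# to vllm.LLM; pins the aclgraph capture sizes used in the PR benchmark.
with VllmRunner(
        "BAAI/bge-m3",
        task="embed",
        enforce_eager=False,
        compilation_config={
            "cudagraph_capture_sizes": [8192, 10240, 20480, 40960, 81920]
        },
) as runner:
    outputs = runner.encode(['What is the capital of China?', 'Explain gravity'])
```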
```diff
@@ -50,6 +50,7 @@ class AttentionMaskBuilder:
         self._seq_len_cached = attn_mask.shape[0]
         self.attn_mask_cache = attn_mask
         self.device = device
+        self.pooling_mask = None
         if torch.version.cann.startswith("8.3"):
             assigned_mask_dim = 2048
             self.chunked_prefill_attn_mask = torch.triu(
@@ -75,6 +76,14 @@ class AttentionMaskBuilder:
         return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
         ).to(device, non_blocking=True)
 
+    def get_pooling_mask(self, device):
+        if self.pooling_mask is None:
+            # the compressed attention mask for npu_fusion_attention sparse mode 4
+            self.pooling_mask = torch.triu(torch.ones(
+                2048, 2048), diagonal=1).to(torch.bool).to(device,
+                                                           non_blocking=True)
+        return self.pooling_mask
+
     def get_splitfuse_attn_mask(
         self,
         seq_lens: torch.Tensor = None,
```
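The new `get_pooling_mask` returns a fixed-size compressed mask rather than one sized to the longest sequence in the batch. A standalone sketch of the idea follows; the claim about reuse across sequence lengths is based on how `npu_fusion_attention` expands a compressed mask in its sparse modes, not on anything added in this diff.

```python
import torch

# A fixed 2048 x 2048 boolean upper-triangular mask: True strictly above
# the diagonal (positions to be masked out), False on and below it.
pooling_mask = torch.triu(torch.ones(2048, 2048), diagonal=1).to(torch.bool)

assert pooling_mask.shape == (2048, 2048)
assert not pooling_mask[1, 0] and pooling_mask[0, 1]
# Because the shape never depends on the batch's sequence lengths, the
# tensor can be built once, cached, and safely baked into a captured graph.
```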
```diff
@@ -606,9 +606,8 @@ class AscendAttentionBackendImpl(AttentionImpl):
         num_actual_tokens = attn_metadata.num_actual_tokens
         assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
         attn_type = self.attn_type
-        if attn_type != AttentionType.DECODER:
-            raise NotImplementedError("Encoder self-attention and "
-                                      "encoder/decoder cross-attention "
+        if attn_type != AttentionType.DECODER and attn_type != AttentionType.ENCODER_ONLY:
+            raise NotImplementedError("Encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "PallasAttentionBackendImpl")
         # View q k v to BSH.
@@ -628,9 +627,25 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 key_cache=self.key_cache,
                 value_cache=self.value_cache,
                 slot_indices=slots)
 
+        if attn_type == AttentionType.ENCODER_ONLY:
+            cum_seq_len = attn_metadata.query_start_loc[1:].tolist()
+            attn_out = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                head_num=self.num_heads,
+                input_layout="TND",
+                scale=self.scale,
+                sparse_mode=4,
+                atten_mask=attn_metadata.attn_mask,
+                pre_tockens=attn_metadata.max_query_len,
+                next_tockens=attn_metadata.max_query_len,
+                actual_seq_qlen=cum_seq_len,
+                actual_seq_kvlen=cum_seq_len,
+            )
+            output = attn_out[0]
         # V0-Style scheduler situation.
-        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
+        elif attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
             output = self._forward_prefill_no_cache(
                 query, key, value, attn_metadata, output, num_tokens)
         elif attn_metadata.attn_state == \
```
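The encoder-only branch feeds packed (TND-layout) tensors to `npu_fusion_attention`, which takes cumulative end offsets per request rather than per-request lengths; `query_start_loc[1:]` provides exactly that. A small illustration with made-up lengths:

```python
import torch

# Three encoder requests of lengths 4, 2 and 3 packed along the token dim.
seq_lens = torch.tensor([4, 2, 3])

# query_start_loc = [0, 4, 6, 9]; dropping the leading zero yields the
# cumulative end offsets used for actual_seq_qlen / actual_seq_kvlen.
query_start_loc = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0), (1, 0))
cum_seq_len = query_start_loc[1:].tolist()
assert cum_seq_len == [4, 6, 9]
```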
```diff
@@ -132,3 +132,22 @@
 #    - this is a bug by Ascend only. It can' be fixed in vLLM.
 #   Future Plan:
 #    Fix this bug in torch-npu, bump torch-npu version and remove this patch.
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+# 1. `vllm.model_executor.models.roberta.RobertaEmbedding.forward`
+#   Why:
+#    shift operation in `_encode_token_type_ids` and `_decode_token_type_ids` cannot run in ascend aclgraph mode
+#   How:
+#    Replace shift operation with multiplication and division.
+#   Related PR (if no, explain why):
+#    No, this need CANN add an aclnn shift operation
+#   Future Plan:
+#    Revert this when CANN support shift aclnn operation
+# 2. `vllm.model_executor.models.roberta.RobertaForSequenceClassification.forward `
+#   Why:
+#    shift operation in `_encode_token_type_ids` and `_decode_token_type_ids` cannot run in ascend aclgraph mode
+#   How:
+#    Replace shift operation with multiplication and division.
+#   Related PR (if no, explain why):
+#    No, this need CANN add an aclnn shift operation
+#   Future Plan:
+#    Revert this when CANN support shift aclnn operation
```
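The shift-to-arithmetic replacement described above is safe because, for non-negative integers, a left shift by k equals multiplication by 2**k and a right shift equals floor division by 2**k. A quick sanity check (values are illustrative):

```python
import torch

TOKEN_TYPE_SHIFT = 30
TOKEN_TYPE_MULTIPLIER = 1 << TOKEN_TYPE_SHIFT

token_ids = torch.tensor([0, 1, 250_000], dtype=torch.int64)
token_types = torch.tensor([0, 1, 1], dtype=torch.int64)

# Encoding: left shift == multiplication for non-negative values.
assert torch.equal(token_types << TOKEN_TYPE_SHIFT,
                   token_types * TOKEN_TYPE_MULTIPLIER)

# Decoding: right shift == floor division on the packed values.
packed = token_ids + token_types * TOKEN_TYPE_MULTIPLIER
assert torch.equal(packed >> TOKEN_TYPE_SHIFT,
                   packed // TOKEN_TYPE_MULTIPLIER)
```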
```diff
@@ -25,6 +25,7 @@ import vllm_ascend.patch.worker.patch_common.patch_attention_selector  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_attention_layer  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_logits  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_roberta  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_weight_loader  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_multimodal_merge  # noqa
 
```
```diff
@@ -64,6 +64,7 @@ def _cached_get_attn_backend(
     use_mla: bool = False,
     use_sfa: bool = False,
     has_sink: bool = False,
     use_sparse: bool = False,
 ) -> type[AttentionBackend]:
     # Check whether a particular choice of backend was
     # previously forced.
```
vllm_ascend/patch/worker/patch_common/patch_roberta.py (new file, 88 lines)

```python
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from typing import Optional

import torch
from vllm.model_executor.models.roberta import (
    RobertaEmbedding, RobertaForSequenceClassification,
    replace_roberta_positions)
from vllm.sequence import IntermediateTensors

# aclgraph does not support shift operator for now
# TODO: revert me when aclgraph supports shift operator
TOKEN_TYPE_SHIFT = 30
TOKEN_TYPE_MULTIPLIER = 1 << 30
TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1


def _encode_token_type_ids(input_ids: torch.Tensor,
                           token_type_ids: torch.Tensor) -> None:
    # input_ids can be padded to the right
    input_ids[:token_type_ids.shape[0]].bitwise_or_(token_type_ids *
                                                    TOKEN_TYPE_MULTIPLIER)


def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:

    token_type_ids = input_ids // TOKEN_TYPE_MULTIPLIER

    input_ids.bitwise_and_(TOKEN_MASK)

    return token_type_ids


def roberta_for_sequence_classification_forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    intermediate_tensors: Optional[IntermediateTensors] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    replace_roberta_positions(input_ids=input_ids,
                              position_ids=positions,
                              padding_idx=self.padding_idx)
    if token_type_ids is not None:
        assert self.roberta.config.vocab_size < (1 << TOKEN_TYPE_SHIFT)
        assert input_ids is not None
        _encode_token_type_ids(input_ids, token_type_ids)
    return self.roberta(input_ids=input_ids,
                        positions=positions,
                        inputs_embeds=inputs_embeds,
                        intermediate_tensors=intermediate_tensors)


def roberta_embedding_forward(
    self,
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
) -> torch.Tensor:

    token_type_ids = _decode_token_type_ids(input_ids)

    inputs_embeds = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(position_ids)

    token_type_embeddings = self.token_type_embeddings(token_type_ids)
    embeddings = inputs_embeds + token_type_embeddings + position_embeddings
    embeddings = self.LayerNorm(embeddings)
    return embeddings


RobertaEmbedding.forward = roberta_embedding_forward
RobertaForSequenceClassification.forward = roberta_for_sequence_classification_forward
```
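A round-trip sketch of the packing scheme used by the two helpers above, written with standalone constants so it runs without importing the patch module (token id values are illustrative):

```python
import torch

TOKEN_TYPE_MULTIPLIER = 1 << 30
TOKEN_MASK = TOKEN_TYPE_MULTIPLIER - 1

input_ids = torch.tensor([101, 2054, 2003, 102], dtype=torch.int64)
token_type_ids = torch.tensor([0, 0, 1, 1], dtype=torch.int64)

# Encode: fold token_type_ids into the high bits of input_ids in place
# (valid because every vocabulary id is below 2**30).
input_ids[:token_type_ids.shape[0]].bitwise_or_(token_type_ids *
                                                TOKEN_TYPE_MULTIPLIER)

# Decode: recover the token types, then strip the high bits in place.
recovered = input_ids // TOKEN_TYPE_MULTIPLIER
input_ids.bitwise_and_(TOKEN_MASK)

assert torch.equal(recovered, token_type_ids)
assert torch.equal(input_ids, torch.tensor([101, 2054, 2003, 102]))
```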
```diff
@@ -134,7 +134,8 @@ class NPUPlatform(Platform):
         structured_outputs_config = vllm_config.structured_outputs_config
 
         if (model_config is not None and not model_config.use_mla
-                and not scheduler_config.async_scheduling):
+                and not scheduler_config.async_scheduling
+                and model_config.runner_type != "pooling"):
             logger.info(
                 "Non-MLA LLMs forcibly disable the chunked prefill feature,"
                 "as the performance of operators supporting this feature "
```
```diff
@@ -77,10 +77,11 @@ from vllm.v1.attention.backends.utils import (
 from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
 # yapf conflicts with isort for this block
 # yapf: disable
-from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
-                                        KVCacheConfig, KVCacheGroupSpec,
-                                        KVCacheSpec, MambaSpec,
-                                        MLAAttentionSpec,
+from vllm.v1.kv_cache_interface import (AttentionSpec,
+                                        EncoderOnlyAttentionSpec,
+                                        FullAttentionSpec, KVCacheConfig,
+                                        KVCacheGroupSpec, KVCacheSpec,
+                                        MambaSpec, MLAAttentionSpec,
                                         UniformTypeKVCacheSpecs)
 # yapf: enable
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
```
```diff
@@ -867,8 +868,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
     def _make_attention_mask(self, seq_lens, position,
                              attn_state) -> torch.Tensor:
+        # Pooling situation.
+        if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
+            return self.attn_mask_builder.get_pooling_mask(self.device)
         # Chunk Prefill situation.
-        if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
+        elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
             if torch.version.cann.startswith("8.3"):
                 return self.attn_mask_builder.get_splitfuse_attn_mask()
             else:
```
```diff
@@ -1426,6 +1430,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # in the same group share the same metadata.
         for kv_cache_group_id, kv_cache_group_spec in enumerate(
                 self.kv_cache_config.kv_cache_groups):
+            if isinstance(kv_cache_group_spec.kv_cache_spec,
+                          EncoderOnlyAttentionSpec):
+                # Encoder-only layers do not have KV cache, so we need to
+                # create a dummy block table and slot mapping for them.
+                blk_table_tensor = torch.zeros(
+                    (num_reqs, 1),
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                slot_mapping = torch.zeros(
+                    (total_num_scheduled_tokens, ),
+                    dtype=torch.int64,
+                    device=self.device,
+                )
+            else:
                 blk_table = self.input_batch.block_table[kv_cache_group_id]
                 blk_table_tensor = blk_table.get_device_tensor()
                 slot_mapping = blk_table.slot_mapping_cpu[:
```
```diff
@@ -1469,7 +1488,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             common_prefix_len = 0
             extra_attn_metadata_args = {}
             builder = attn_group.get_metadata_builder()
-            if isinstance(builder, GDNAttentionMetadataBuilder):
+            if isinstance(builder, GDNAttentionMetadataBuilder
+                          ) or self.model_config.runner_type == "pooling":
                 if use_spec_decode:
                     extra_attn_metadata_args = dict(
                         num_accepted_tokens=self.num_accepted_tokens.
```
```diff
@@ -2625,7 +2645,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         """
         kv_cache_config = deepcopy(kv_cache_config)
         self.kv_cache_config = kv_cache_config
-        self.initialize_attn_backend(kv_cache_config)
         self.use_hybrid_blocks = (len(self.attn_groups) > 1)
         # NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
         self.need_accepted_tokens = any([
```
```diff
@@ -2634,6 +2653,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         ])
 
         self.may_reinitialize_input_batch(kv_cache_config)
+        self.may_add_encoder_only_layers_to_kv_cache_config()
+        self.initialize_attn_backend(kv_cache_config)
 
         if self.ascend_config.is_deepseek_sfa:
             kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
```
```diff
@@ -3089,6 +3110,31 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             kernel_block_sizes=kernel_block_sizes,
         )
 
+    def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
+        """
+        Add encoder-only layers to the KV cache config.
+        """
+        block_size = self.vllm_config.cache_config.block_size
+        encoder_only_attn_specs: dict[AttentionSpec,
+                                      list[str]] = defaultdict(list)
+        attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer_name, attn_module in attn_layers.items():
+            if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                attn_spec: AttentionSpec = EncoderOnlyAttentionSpec(
+                    block_size=block_size,
+                    num_kv_heads=attn_module.num_kv_heads,
+                    head_size=attn_module.head_size,
+                    dtype=self.kv_cache_dtype)
+                encoder_only_attn_specs[attn_spec].append(layer_name)
+                self.runner_only_attn_layers.add(layer_name)
+        if len(encoder_only_attn_specs) > 0:
+            assert len(
+                encoder_only_attn_specs
+            ) == 1, "Only support one encoder-only attention spec now"
+            spec, layer_names = encoder_only_attn_specs.popitem()
+            self.kv_cache_config.kv_cache_groups.append(
+                KVCacheGroupSpec(layer_names=layer_names, kv_cache_spec=spec))
+
     def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None:
         """
         Initialize the attention backends and attention metadata builders.
```