[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)

### What this PR does / why we need it?

- Refactor and integrate a unified `WeightPrefetchMethod`
- Integrate prefetching of `qkv_proj.weight` and `o_proj.weight` into quantized Attention
modules
- Prefetching these weights ahead of matmul-like operators improves
performance by reducing L2 cache transfer latency

### Does this PR introduce _any_ user-facing change?

Add a new `weight_prefetch_config` option to `--additional-config`:
```json
{
    "weight_prefetch_config": {
        "enabled": false,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0
            }
        }
    }
}
```
This feature is disabled by default and can be enabled through this
configuration.
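
For example, a minimal offline-inference launch that enables the feature explicitly (the model name is only a placeholder):
```python
from vllm import LLM

# Illustrative sketch: any model served through vLLM Ascend is configured the same way.
llm = LLM(
    model="Qwen/Qwen3-8B",
    additional_config={
        "weight_prefetch_config": {
            "enabled": True,
            "prefetch_ratio": {
                "attn": {"qkv": 1.0, "o": 1.0},
            },
        },
    },
)
```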

### How was this patch tested?


- vLLM version: v0.11.0

---------

Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
Ruri
2025-10-09 20:38:39 +08:00
committed by GitHub
parent 23db56a340
commit ff37575936
13 changed files with 264 additions and 69 deletions

View File

@ -24,24 +24,24 @@ LLM(model="Qwen/Qwen3-8B", additional_config={"config_key":"config_value"})
The following table lists the additional configuration options available in vLLM Ascend:

| Name | Type | Default | Description |
|-------------------------------------|------|---------|-------------|
| `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode |
| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
| `weight_prefetch_config` | dict | `{}` | The config options for weight prefetch |
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test case. |
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `enable_shared_expert_dp` | bool | `False` | When the shared expert runs in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported. |
| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
| `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. |
| `multistream_overlap_shared_expert` | bool | `False` | Whether to enable multistream shared expert. This option only takes effect on MoE models with shared experts. |
| `dynamic_eplb` | bool | `False` | Whether to enable dynamic eplb |
| `num_iterations_eplb_update` | int | `400` | Forward iterations when eplb would begin |
| `gate_eplb` | bool | `False` | Whether to enable eplb only once. |
| `num_wait_worker_iterations` | int | `30` | The forward iterations when the eplb worker will finish the cpu task. In our tests the default value 30 covers most cases. |
| `expert_map_record_path` | str | `None` | When dynamic eplb is completed, save the current expert load heatmap to the specified path. |
| `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |

The details of each config option are as follows:
@ -71,6 +71,13 @@ The details of each config option are as follows:
ascend_scheduler_config also supports the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.
**weight_prefetch_config**
| Name | Type | Default | Description |
|------------------|------|------------------------------------|------------------------------------|
| `enabled` | bool | `False` | Whether to enable weight prefetch. |
| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}}` | Prefetch ratio of each weight. |
### Example
An example of additional configuration is as follows:
@ -90,6 +97,15 @@ An example of additional configuration is as follows:
"max_long_partial_prefills": 1, "max_long_partial_prefills": 1,
"long_prefill_token_threshold": 4096, "long_prefill_token_threshold": 4096,
}, },
"weight_prefetch_config": {
"enabled": True,
"prefetch_ratio": {
"attn": {
"qkv": 1.0,
"o": 1.0,
},
},
},
"multistream_overlap_shared_expert": True, "multistream_overlap_shared_expert": True,
"refresh": False, "refresh": False,
} }

View File

@ -495,7 +495,7 @@ class TestAscendMLAImpl(TestBase):
mock_up_proj.assert_called_once()
mock_npu_fused_infer_attention_score.assert_called_once()
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess(self, magic_npu_fetch):
magic_npu_fetch.return_value = MagicMock()
batch_size = 4

View File

@ -68,16 +68,23 @@ class TestAscendW8A8LinearMethod(TestBase):
self.assertEqual(params['weight_scale'].shape, (10, 1))
self.assertEqual(params['weight_offset'].shape, (10, 1))
@patch("vllm_ascend.quantization.w8a8.get_forward_context")
@patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
@patch("torch_npu.npu_quant_matmul")
def test_apply_with_x_not_int8(self, mock_npu_quant_matmul,
mock_quant_per_tensor,
mock_get_forward_context):
layer = MagicMock()
layer.aclnn_input_scale = 0.1
layer.aclnn_input_offset = 0.2
layer.weight = torch.randn(128, 256)
layer.deq_scale = 0.3
mock_forward_context = MagicMock()
mock_get_forward_context.return_value = mock_forward_context
mock_weight_prefetch_method = MagicMock()
mock_forward_context.weight_prefetch_method = mock_weight_prefetch_method
x = torch.randn(32, 128)
bias = torch.randn(256)
mock_quant_per_tensor.return_value = torch.randint(-128,

View File

@ -45,6 +45,12 @@ class AscendConfig:
"ascend_scheduler_config", {}) "ascend_scheduler_config", {})
self.ascend_scheduler_config = AscendSchedulerConfig( self.ascend_scheduler_config = AscendSchedulerConfig(
ascend_scheduler_config) ascend_scheduler_config)
weight_prefetch_config = additional_config.get(
"weight_prefetch_config", {})
self.weight_prefetch_config = WeightPrefetchConfig(
weight_prefetch_config)
# Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
self.expert_map_path = additional_config.get("expert_map_path", None) self.expert_map_path = additional_config.get("expert_map_path", None)
self.expert_map_record_path = additional_config.get( self.expert_map_record_path = additional_config.get(
@ -65,7 +71,6 @@ class AscendConfig:
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
self.multistream_overlap_shared_expert = additional_config.get(
"multistream_overlap_shared_expert", False)
self.enable_prefetch = additional_config.get("enable_prefetch", False)
self.lmhead_tensor_parallel_size = additional_config.get(
"lmhead_tensor_parallel_size", None)
if self.lmhead_tensor_parallel_size is not None:
@ -185,6 +190,24 @@ class AscendSchedulerConfig:
setattr(self, k, v)
class WeightPrefetchConfig:
"""
Configuration Object for weight_prefetch_config from additional_config
"""
prefetch_ratio: dict = {
"attn": {
"qkv": 1.0,
"o": 1.0,
},
}
def __init__(self, weight_prefetch_config: dict):
self.enabled = weight_prefetch_config.get("enabled", False)
self.prefetch_ratio = weight_prefetch_config.get(
"prefetch_ratio", self.prefetch_ratio)
_ASCEND_CONFIG: Optional[AscendConfig] = None

View File

@ -1,7 +1,7 @@
import math
from contextlib import contextmanager
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional
import torch
from vllm.config import CUDAGraphMode, VllmConfig
@ -13,6 +13,11 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp
if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
else:
WeightPrefetchMethod = None
class FusedMoEState(Enum):
AllGather = 0
@ -65,7 +70,8 @@ def set_ascend_forward_context(
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor: Optional[BatchDescriptor] = None,
prefetch_stream: torch.npu.Stream = None,
model_instance: torch.nn.Module = None,
weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
We add some additional param into forward_context.
@ -127,6 +133,7 @@ def set_ascend_forward_context(
hasattr(model_instance.model, "start_layer"):
forward_context.layer_idx = model_instance.model.start_layer
# TODO(rjg-lyh): refactor mlp weight prefetch method
# set for mlp weight prefetch
prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
@ -138,6 +145,8 @@ def set_ascend_forward_context(
forward_context.prefetch_mlp_gate_up_proj = False
forward_context.prefetch_mlp_down_proj = False
forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
# TODO(yuzhup): integrate moe weight prefetch method
forward_context.weight_prefetch_method = weight_prefetch_method
# TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
# It will be improved later by implementing operator fusion through the FX graph.

View File

@ -24,7 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@ -493,7 +493,7 @@ class AscendMLAImpl(MLAAttentionImpl):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
vllm_config = get_current_vllm_config()
@ -877,9 +877,9 @@ class AscendMLAImpl(MLAAttentionImpl):
num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = attn_metadata.num_actual_tokens
if self.q_a_proj is not None:
maybe_npu_prefetch(inputs=self.q_a_proj.weight,
dependency=hidden_states,
enabled=self.enable_prefetch)
ckq = self.q_a_proj(hidden_states)[0]
q_c = self.q_a_layernorm(ckq)
else:
@ -1005,10 +1005,10 @@ class AscendMLAImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
if current_ms_metadata is None:
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(
o_proj_input,
@ -1016,10 +1016,10 @@ class AscendMLAImpl(MLAAttentionImpl):
is_force_scatter=self.enable_shared_expert_dp)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(
o_proj_input,
is_prefill=prefill_preprocess_res is not None,

View File

@ -11,6 +11,8 @@ from vllm.utils import direct_register_custom_op
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
def _maybe_chunk_residual_impl(x: torch.Tensor,
@ -148,6 +150,33 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
return
def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
max_weight_size: int) -> None:
calculation_stream = torch_npu.npu.current_stream()
weight_prefetch_stream = prefetch_stream()
weight_prefetch_stream.wait_stream(calculation_stream)
with npu_stream_switch(weight_prefetch_stream):
maybe_npu_prefetch(inputs=weight,
dependency=start_flag,
max_size=max_weight_size)
def _prefetch_preprocess_impl_fake(weight: torch.Tensor,
start_flag: torch.Tensor,
max_weight_size: int) -> None:
return
def _prefetch_postprocess_impl(stop_flag: torch.Tensor) -> None:
calculation_stream = torch_npu.npu.current_stream()
weight_prefetch_stream = prefetch_stream()
calculation_stream.wait_stream(weight_prefetch_stream)
def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
return
def _maybe_all_reduce_tensor_model_parallel_impl(
final_hidden_states: torch.Tensor) -> torch.Tensor:
forward_context = get_forward_context()
@ -194,6 +223,18 @@ direct_register_custom_op(op_name="maybe_wait_prefetch_done",
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="prefetch_preprocess",
op_func=_prefetch_preprocess_impl,
fake_impl=_prefetch_preprocess_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="prefetch_postprocess",
op_func=_prefetch_postprocess_impl,
fake_impl=_prefetch_postprocess_impl_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel", direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
op_func=_maybe_all_reduce_tensor_model_parallel_impl, op_func=_maybe_all_reduce_tensor_model_parallel_impl,
fake_impl=lambda x: x, fake_impl=lambda x: x,

View File

@ -0,0 +1,75 @@
from dataclasses import dataclass, field
import torch
import torch_npu
from vllm_ascend.ascend_config import WeightPrefetchConfig
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
@dataclass
class ModuleWeightPrefetchConfig:
module_name: str
enable: bool = False
prefetch_ratio: dict = field(default_factory=dict)
def __post_init__(self) -> None:
self.prefetch_ratio = {
prefix: ratio
for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
}
assert self.module_name in SUPPORTED_MODULES, (
f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
)
if self.module_name in SUPPORTED_MODULES:
self.enable = self.enable and any(self.prefetch_ratio.values()) > 0
class WeightPrefetchMethod:
"""
Unified weight prefetch method.
"""
def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
self.attn = ModuleWeightPrefetchConfig(
module_name="attn",
enable=weight_prefetch_config.enabled,
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
"attn", {}))
def maybe_prefetch_attn_weight_preprocess(
self, prefix: str, weight: torch.Tensor,
start_flag: torch.Tensor) -> None:
if not self.attn.enable:
return
weight_size = weight.data.element_size() * weight.data.numel(
) * self.attn.prefetch_ratio.get(prefix, 0)
torch.ops.vllm.prefetch_preprocess(weight=weight,
start_flag=start_flag,
max_weight_size=int(weight_size))
def maybe_prefetch_attn_weight_postprocess(
self, stop_flag: torch.Tensor) -> None:
if not self.attn.enable:
return
torch.ops.vllm.prefetch_postprocess(stop_flag)
def maybe_npu_prefetch(inputs: torch.Tensor,
dependency: torch.Tensor,
max_size: int = 0,
offset: int = 0,
*,
enabled: bool = True) -> None:
if not enabled:
return
input_size = inputs.element_size() * inputs.numel()
if max_size <= 0 or max_size > input_size:
max_size = input_size
torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
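
For reference, a minimal sketch of how the prefetch budget falls out of `prefetch_ratio` before it is handed to the custom op; the 4096x4096 bf16 weight shape and the 0.5 ratio are assumed examples, not taken from a specific model:
```python
import torch

# Assumed example weight: an o_proj of shape (4096, 4096) in bfloat16.
o_proj_weight = torch.empty(4096, 4096, dtype=torch.bfloat16)
prefetch_ratio = {"qkv": 1.0, "o": 0.5}  # hypothetical ratios

total_bytes = o_proj_weight.element_size() * o_proj_weight.numel()  # 33554432 bytes (32 MiB)
max_weight_size = int(total_bytes * prefetch_ratio.get("o", 0))     # 16777216 bytes prefetched

print(f"prefetch {max_weight_size} of {total_bytes} bytes")
```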

View File

@ -21,6 +21,7 @@ import torch
import torch_npu
from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.moe.experts_selector import select_experts
@ -97,11 +98,32 @@ class AscendW8A8LinearMethod:
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
if x.dtype != torch.int8:
attn_weight_map = {
"AscendQKVParallelLinear": "qkv",
"AscendRowParallelLinear": "o",
}
layer_cls_name = layer.__class__.__name__
weight_prefetch_method = get_forward_context(
).weight_prefetch_method
assert weight_prefetch_method is not None
# prefetch_qkvo_proj.weight preprocess
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
prefix=attn_weight_map.get(layer_cls_name, ""),
weight=layer.weight,
start_flag=x,
)
# quant
x = quant_per_tensor(
x,
layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset,
)
# prefetch_qkvo_proj.weight postprocess
if layer_cls_name in attn_weight_map.keys():
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
x)
quant_bias = layer.quant_bias if tp_rank == 0 else None
if is_310p():
# On 300I Duo platform, we need transpose again if

View File

@ -70,11 +70,12 @@ from vllm.sequence import IntermediateTensors
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import Indexer
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
TorchairAscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor, oproj_tp_enable
class TorchairDeepseekV2SiluAndMul(SiluAndMul):
@ -589,9 +590,9 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
and attn_metadata.num_decodes > 0)
forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
if self.q_lora_rank is not None:
maybe_npu_prefetch(self.q_a_proj.weight,
hidden_states,
enabled=enable_multistream_mla)
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
forward_kwargs['ckq'] = ckq

View File

@ -23,9 +23,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
npu_stream_switch, npu_wait_tensor)
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
@ -684,10 +684,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
if hasattr(self, "running_in_graph") and not self.running_in_graph:
return x
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
maybe_npu_prefetch(self.o_proj.weight,
x,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
return self.o_proj(x, is_prefill=False)[0]
# Return `ql_nope`, `q_pe`
@ -1281,10 +1281,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
if current_ms_metadata is None:
maybe_npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
@ -1292,10 +1292,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
is_force_scatter=self.enable_shared_expert_dp)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(self.o_proj.weight,
o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=enable_multistream_mla)
output[...] = self.o_proj(
o_proj_input,
is_prefill=True,

View File

@ -51,6 +51,7 @@ _CUSTOM_OP_ENABLED = None
_IS_310P = None
_SLEEP_MODE_ENABLED = None
_CURRENT_STREAM = None
_PREFETCH_STREAM = None
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
@ -241,6 +242,15 @@ def current_stream() -> torch.npu.Stream:
return _CURRENT_STREAM
def prefetch_stream() -> torch.npu.Stream:
global _PREFETCH_STREAM
if _PREFETCH_STREAM is None:
# lazily create a dedicated stream for weight prefetching
# the first time this function is called.
_PREFETCH_STREAM = torch_npu.npu.Stream()
return _PREFETCH_STREAM
def adapt_patch(is_global_patch: bool = False):
if is_global_patch:
from vllm_ascend.patch import platform  # noqa: F401
@ -446,20 +456,6 @@ class ProfileExecuteDuration:
return durations
# TODO(wxy): Move to ops module
def npu_prefetch(input: torch.Tensor,
dependency: torch.Tensor,
max_size: int = 0,
*,
enabled: bool = True):
if not enabled:
return
input_size = input.element_size() * input.numel()
if max_size <= 0 or max_size > input_size:
max_size = input_size
torch_npu.npu_prefetch(input, dependency, max_size)
# TODO(ttanzhiqiang): rm_router_logits
# dp>1 will trigger
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.

View File

@ -113,6 +113,7 @@ from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@ -285,6 +286,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
else:
self.chunked_prefill_enabled = True
self.weight_prefetch_method = WeightPrefetchMethod(
self.ascend_config.weight_prefetch_config)
if self.cache_config.cache_dtype == "auto":
self.kv_cache_dtype = self.dtype
@ -1856,7 +1859,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_actual_tokens=scheduler_output.
total_num_scheduled_tokens,
prefetch_stream=self.prefetch_stream,
model_instance=self.model,
weight_prefetch_method=self.weight_prefetch_method):
self.maybe_setup_kv_connector(scheduler_output)
hidden_states = self._generate_process_reqs_hidden_states(
@ -2370,7 +2374,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor,
prefetch_stream=self.prefetch_stream,
model_instance=self.model,
weight_prefetch_method=self.weight_prefetch_method):
hidden_states = self._generate_dummy_run_hidden_states(
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors,