Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)
### What this PR does / why we need it?

- Refactor and integrate a unified `WeightPrefetchMethod`
- Integrate `qkv_proj.weight` and `o_proj.weight` in quantized Attention modules
- Prefetching these weights ahead of matmul-like operators improves performance by reducing L2 cache transfer latency

### Does this PR introduce _any_ user-facing change?

Adds a new `weight_prefetch_config` option under `--additional-config`:

```json
{
    "weight_prefetch_config": {
        "enabled": false,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0
            }
        }
    }
}
```

This feature is enabled by default and can be disabled through this configuration.

### How was this patch tested?

- vLLM version: v0.11.0

---------

Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
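For quick reference, a hedged usage sketch of the new option: it follows the `LLM(..., additional_config=...)` pattern from the docs change below; the model name and ratios are illustrative placeholders, not recommendations.

```python
from vllm import LLM

# Sketch only: enable attention weight prefetch via additional_config
# (keys mirror the weight_prefetch_config documented in this PR).
llm = LLM(
    model="Qwen/Qwen3-8B",
    additional_config={
        "weight_prefetch_config": {
            "enabled": True,
            "prefetch_ratio": {
                "attn": {"qkv": 1.0, "o": 1.0},
            },
        },
    },
)
```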
@ -25,12 +25,12 @@ LLM(model="Qwen/Qwen3-8B", additional_config={"config_key":"config_value"})
The following table lists the additional configuration options available in vLLM Ascend:

| Name | Type | Default | Description |
|-------------------------------| ---- |------|-----------------------------------------------------------------------------------------------|
|-------------------------------------|------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------|
| `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode |
| `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
| `weight_prefetch_config` | dict | `{}` | The config options for weight prefetch |
| `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test cases. |
| `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
| `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set, currently only int8 is supported. |
| `enable_shared_expert_dp` | bool | `False` | When the shared expert is in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported. |
| `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
@ -71,6 +71,13 @@ The details of each config option are as follows:

ascend_scheduler_config also supports the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.

**weight_prefetch_config**

| Name | Type | Default | Description |
|------------------|------|------------------------------------|------------------------------------|
| `enabled` | bool | `False` | Whether to enable weight prefetch. |
| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}}` | Prefetch ratio of each weight. |

### Example

An example of additional configuration is as follows:
@ -90,6 +97,15 @@ An example of additional configuration is as follows:
        "max_long_partial_prefills": 1,
        "long_prefill_token_threshold": 4096,
    },
    "weight_prefetch_config": {
        "enabled": True,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0,
            },
        },
    },
    "multistream_overlap_shared_expert": True,
    "refresh": False,
}
@ -495,7 +495,7 @@ class TestAscendMLAImpl(TestBase):
        mock_up_proj.assert_called_once()
        mock_npu_fused_infer_attention_score.assert_called_once()

    @patch("vllm_ascend.attention.mla_v1.npu_prefetch")
    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
    def test_mla_preprocess(self, magic_npu_fetch):
        magic_npu_fetch.return_value = MagicMock()
        batch_size = 4
@ -68,16 +68,23 @@ class TestAscendW8A8LinearMethod(TestBase):
        self.assertEqual(params['weight_scale'].shape, (10, 1))
        self.assertEqual(params['weight_offset'].shape, (10, 1))

    @patch("vllm_ascend.quantization.w8a8.get_forward_context")
    @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
    @patch("torch_npu.npu_quant_matmul")
    def test_apply_with_x_not_int8(self, mock_npu_quant_matmul,
                                   mock_quant_per_tensor):
                                   mock_quant_per_tensor,
                                   mock_get_forward_context):
        layer = MagicMock()
        layer.aclnn_input_scale = 0.1
        layer.aclnn_input_offset = 0.2
        layer.weight = torch.randn(128, 256)
        layer.deq_scale = 0.3

        mock_forward_context = MagicMock()
        mock_get_forward_context.return_value = mock_forward_context
        mock_weight_prefetch_method = MagicMock()
        mock_forward_context.weight_prefetch_method = mock_weight_prefetch_method

        x = torch.randn(32, 128)
        bias = torch.randn(256)
        mock_quant_per_tensor.return_value = torch.randint(-128,
@ -45,6 +45,12 @@ class AscendConfig:
            "ascend_scheduler_config", {})
        self.ascend_scheduler_config = AscendSchedulerConfig(
            ascend_scheduler_config)

        weight_prefetch_config = additional_config.get(
            "weight_prefetch_config", {})
        self.weight_prefetch_config = WeightPrefetchConfig(
            weight_prefetch_config)

        # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
        self.expert_map_path = additional_config.get("expert_map_path", None)
        self.expert_map_record_path = additional_config.get(
@ -65,7 +71,6 @@ class AscendConfig:
        ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
        self.multistream_overlap_shared_expert = additional_config.get(
            "multistream_overlap_shared_expert", False)
        self.enable_prefetch = additional_config.get("enable_prefetch", False)
        self.lmhead_tensor_parallel_size = additional_config.get(
            "lmhead_tensor_parallel_size", None)
        if self.lmhead_tensor_parallel_size is not None:
@ -185,6 +190,24 @@ class AscendSchedulerConfig:
            setattr(self, k, v)


class WeightPrefetchConfig:
    """
    Configuration Object for weight_prefetch_config from additional_config
    """

    prefetch_ratio: dict = {
        "attn": {
            "qkv": 1.0,
            "o": 1.0,
        },
    }

    def __init__(self, weight_prefetch_config: dict):
        self.enabled = weight_prefetch_config.get("enabled", False)
        self.prefetch_ratio = weight_prefetch_config.get(
            "prefetch_ratio", self.prefetch_ratio)


_ASCEND_CONFIG: Optional[AscendConfig] = None
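For illustration, a small sketch of how the `WeightPrefetchConfig` defined above resolves user input; note that a user-supplied `prefetch_ratio` replaces the default dict wholesale rather than being deep-merged (behavior inferred from the `dict.get` calls in the diff):

```python
# Assumes the WeightPrefetchConfig class shown in the diff above is importable.
cfg = WeightPrefetchConfig({
    "enabled": True,
    "prefetch_ratio": {"attn": {"qkv": 1.0, "o": 0.5}},
})
assert cfg.enabled is True
assert cfg.prefetch_ratio["attn"]["o"] == 0.5

# An empty dict falls back to the class defaults: disabled, with qkv/o ratios of 1.0.
default_cfg = WeightPrefetchConfig({})
assert default_cfg.enabled is False
assert default_cfg.prefetch_ratio == {"attn": {"qkv": 1.0, "o": 1.0}}
```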
@ -1,7 +1,7 @@
import math
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional

import torch
from vllm.config import CUDAGraphMode, VllmConfig
@ -13,6 +13,11 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp

if TYPE_CHECKING:
    from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
else:
    WeightPrefetchMethod = None


class FusedMoEState(Enum):
    AllGather = 0
@ -65,7 +70,8 @@ def set_ascend_forward_context(
        aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        prefetch_stream: torch.npu.Stream = None,
        model_instance: torch.nn.Module = None):
        model_instance: torch.nn.Module = None,
        weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    We add some additional param into forward_context.
@ -127,6 +133,7 @@ def set_ascend_forward_context(
                hasattr(model_instance.model, "start_layer"):
            forward_context.layer_idx = model_instance.model.start_layer

        # TODO(rjg-lyh): refactor mlp weight prefetch method
        # set for mlp weight prefetch
        prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
            envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \
@ -138,6 +145,8 @@ def set_ascend_forward_context(
        forward_context.prefetch_mlp_gate_up_proj = False
        forward_context.prefetch_mlp_down_proj = False
        forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
        # TODO(yuzhup): integrate moe weight prefetch method
        forward_context.weight_prefetch_method = weight_prefetch_method

        # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
        # It will be improved later by implementing operator fusion through the FX graph.
@ -24,7 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch

if TYPE_CHECKING:
@ -493,7 +493,7 @@ class AscendMLAImpl(MLAAttentionImpl):

        ascend_config = get_ascend_config()
        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
        self.enable_prefetch = ascend_config.enable_prefetch
        self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz

        vllm_config = get_current_vllm_config()
@ -877,8 +877,8 @@ class AscendMLAImpl(MLAAttentionImpl):
        num_decode_tokens = attn_metadata.num_decode_tokens
        num_actual_tokens = attn_metadata.num_actual_tokens
        if self.q_a_proj is not None:
            npu_prefetch(self.q_a_proj.weight,
                         hidden_states,
            maybe_npu_prefetch(inputs=self.q_a_proj.weight,
                               dependency=hidden_states,
                               enabled=self.enable_prefetch)
            ckq = self.q_a_proj(hidden_states)[0]
            q_c = self.q_a_layernorm(ckq)
@ -1005,8 +1005,8 @@ class AscendMLAImpl(MLAAttentionImpl):
        current_ms_metadata = get_multistream_comm_context()
        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
        if current_ms_metadata is None:
            npu_prefetch(self.o_proj.weight,
                         o_proj_input,
            maybe_npu_prefetch(inputs=self.o_proj.weight,
                               dependency=o_proj_input,
                               max_size=MAX_O_PROJ_PREFETCH_SIZE,
                               enabled=self.enable_prefetch)

@ -1016,8 +1016,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                is_force_scatter=self.enable_shared_expert_dp)[0]
        else:
            with torch.npu.stream(current_ms_metadata.comm_stream):
                npu_prefetch(self.o_proj.weight,
                             o_proj_input,
                maybe_npu_prefetch(inputs=self.o_proj.weight,
                                   dependency=o_proj_input,
                                   max_size=MAX_O_PROJ_PREFETCH_SIZE,
                                   enabled=self.enable_prefetch)
                output[...] = self.o_proj(
@ -11,6 +11,8 @@ from vllm.utils import direct_register_custom_op

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream


def _maybe_chunk_residual_impl(x: torch.Tensor,
@ -148,6 +150,33 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
    return


def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
                              max_weight_size: int) -> None:
    calculation_stream = torch_npu.npu.current_stream()
    weight_prefetch_stream = prefetch_stream()
    weight_prefetch_stream.wait_stream(calculation_stream)
    with npu_stream_switch(weight_prefetch_stream):
        maybe_npu_prefetch(inputs=weight,
                           dependency=start_flag,
                           max_size=max_weight_size)


def _prefetch_preprocess_impl_fake(weight: torch.Tensor,
                                   start_flag: torch.Tensor,
                                   max_weight_size: int) -> None:
    return


def _prefetch_postprocess_impl(stop_flag: torch.Tensor) -> None:
    calculation_stream = torch_npu.npu.current_stream()
    weight_prefetch_stream = prefetch_stream()
    calculation_stream.wait_stream(weight_prefetch_stream)


def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
    return


def _maybe_all_reduce_tensor_model_parallel_impl(
        final_hidden_states: torch.Tensor) -> torch.Tensor:
    forward_context = get_forward_context()
@ -194,6 +223,18 @@ direct_register_custom_op(op_name="maybe_wait_prefetch_done",
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="prefetch_preprocess",
                          op_func=_prefetch_preprocess_impl,
                          fake_impl=_prefetch_preprocess_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="prefetch_postprocess",
                          op_func=_prefetch_postprocess_impl,
                          fake_impl=_prefetch_postprocess_impl_fake,
                          mutates_args=[],
                          dispatch_key="PrivateUse1")

direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
                          op_func=_maybe_all_reduce_tensor_model_parallel_impl,
                          fake_impl=lambda x: x,
vllm_ascend/ops/weight_prefetch.py (new file, 75 lines)
@ -0,0 +1,75 @@
from dataclasses import dataclass, field

import torch
import torch_npu

from vllm_ascend.ascend_config import WeightPrefetchConfig

SUPPORTED_MODULES = ["attn", "mlp", "moe"]


@dataclass
class ModuleWeightPrefetchConfig:
    module_name: str
    enable: bool = False
    prefetch_ratio: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        self.prefetch_ratio = {
            prefix: ratio
            for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
        }

        assert self.module_name in SUPPORTED_MODULES, (
            f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
        )

        if self.module_name in SUPPORTED_MODULES:
            self.enable = self.enable and any(self.prefetch_ratio.values()) > 0


class WeightPrefetchMethod:
    """
    Unified weight prefetch method.
    """

    def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
        self.attn = ModuleWeightPrefetchConfig(
            module_name="attn",
            enable=weight_prefetch_config.enabled,
            prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
                "attn", {}))

    def maybe_prefetch_attn_weight_preprocess(
            self, prefix: str, weight: torch.Tensor,
            start_flag: torch.Tensor) -> None:
        if not self.attn.enable:
            return

        weight_size = weight.data.element_size() * weight.data.numel(
        ) * self.attn.prefetch_ratio.get(prefix, 0)

        torch.ops.vllm.prefetch_preprocess(weight=weight,
                                           start_flag=start_flag,
                                           max_weight_size=int(weight_size))

    def maybe_prefetch_attn_weight_postprocess(
            self, stop_flag: torch.Tensor) -> None:
        if not self.attn.enable:
            return

        torch.ops.vllm.prefetch_postprocess(stop_flag)


def maybe_npu_prefetch(inputs: torch.Tensor,
                       dependency: torch.Tensor,
                       max_size: int = 0,
                       offset: int = 0,
                       *,
                       enabled: bool = True) -> None:
    if not enabled:
        return
    input_size = inputs.element_size() * inputs.numel()
    if max_size <= 0 or max_size > input_size:
        max_size = input_size
    torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
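A minimal usage sketch of `maybe_npu_prefetch` from the new module above, assuming an Ascend environment where `torch_npu` is available; the tensor shapes and the 16 MB cap are illustrative (the cap mirrors `MAX_O_PROJ_PREFETCH_SIZE` used elsewhere in this PR):

```python
import torch
import torch_npu  # noqa: F401  (Ascend backend; provides the "npu" device)

from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch

weight = torch.randn(4096, 4096, dtype=torch.float16, device="npu")
activation = torch.randn(32, 4096, dtype=torch.float16, device="npu")

# Issue an async prefetch of up to 16 MB of the weight into L2 cache once the
# dependency tensor (the activation) is ready; with enabled=False the call is
# a no-op, which is how the weight_prefetch_config toggle propagates.
maybe_npu_prefetch(inputs=weight,
                   dependency=activation,
                   max_size=16 * 1024 * 1024,
                   enabled=True)
```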
@ -21,6 +21,7 @@ import torch
import torch_npu
from vllm.attention.backends.abstract import AttentionType
from vllm.distributed.parallel_state import get_ep_group
from vllm.forward_context import get_forward_context

from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.moe.experts_selector import select_experts
@ -97,11 +98,32 @@ class AscendW8A8LinearMethod:
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        if x.dtype != torch.int8:
            attn_weight_map = {
                "AscendQKVParallelLinear": "qkv",
                "AscendRowParallelLinear": "o",
            }
            layer_cls_name = layer.__class__.__name__
            weight_prefetch_method = get_forward_context(
            ).weight_prefetch_method
            assert weight_prefetch_method is not None

            # prefetch_qkvo_proj.weight preprocess
            weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
                prefix=attn_weight_map.get(layer_cls_name, ""),
                weight=layer.weight,
                start_flag=x,
            )
            # quant
            x = quant_per_tensor(
                x,
                layer.aclnn_input_scale_reciprocal,
                layer.aclnn_input_offset,
            )
            # prefetch_qkvo_proj.weight postprocess
            if layer_cls_name in attn_weight_map.keys():
                weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
                    x)

        quant_bias = layer.quant_bias if tp_rank == 0 else None
        if is_310p():
            # On 300I Duo platform, we need transpose again if
@ -70,11 +70,12 @@ from vllm.sequence import IntermediateTensors
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import Indexer
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
    TorchairAscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor, npu_prefetch, oproj_tp_enable
from vllm_ascend.utils import dispose_tensor, oproj_tp_enable


class TorchairDeepseekV2SiluAndMul(SiluAndMul):
@ -589,7 +590,7 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
                              and attn_metadata.num_decodes > 0)
        forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
        if self.q_lora_rank is not None:
            npu_prefetch(self.q_a_proj.weight,
            maybe_npu_prefetch(self.q_a_proj.weight,
                               hidden_states,
                               enabled=enable_multistream_mla)
            ckq = self.q_a_proj(hidden_states)[0]
@ -23,9 +23,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
                                        npu_stream_switch, npu_wait_tensor)
from vllm_ascend.utils import npu_prefetch
from vllm_ascend.worker.npu_input_batch import InputBatch

if TYPE_CHECKING:
@ -684,7 +684,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
        if hasattr(self, "running_in_graph") and not self.running_in_graph:
            return x
        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
        npu_prefetch(self.o_proj.weight,
        maybe_npu_prefetch(self.o_proj.weight,
                           x,
                           max_size=MAX_O_PROJ_PREFETCH_SIZE,
                           enabled=enable_multistream_mla)
@ -1281,7 +1281,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
        current_ms_metadata = get_multistream_comm_context()
        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
        if current_ms_metadata is None:
            npu_prefetch(self.o_proj.weight,
            maybe_npu_prefetch(self.o_proj.weight,
                               o_proj_input,
                               max_size=MAX_O_PROJ_PREFETCH_SIZE,
                               enabled=enable_multistream_mla)
@ -1292,7 +1292,7 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                is_force_scatter=self.enable_shared_expert_dp)[0]
        else:
            with torch.npu.stream(current_ms_metadata.comm_stream):
                npu_prefetch(self.o_proj.weight,
                maybe_npu_prefetch(self.o_proj.weight,
                                   o_proj_input,
                                   max_size=MAX_O_PROJ_PREFETCH_SIZE,
                                   enabled=enable_multistream_mla)
@ -51,6 +51,7 @@ _CUSTOM_OP_ENABLED = None
_IS_310P = None
_SLEEP_MODE_ENABLED = None
_CURRENT_STREAM = None
_PREFETCH_STREAM = None
_ASCEND_CUSTOMOP_IS_REIGISTERED = False


@ -241,6 +242,15 @@ def current_stream() -> torch.npu.Stream:
    return _CURRENT_STREAM


def prefetch_stream() -> torch.npu.Stream:
    global _PREFETCH_STREAM
    if _PREFETCH_STREAM is None:
        # lazily create a dedicated stream for weight prefetching
        # on first use
        _PREFETCH_STREAM = torch_npu.npu.Stream()
    return _PREFETCH_STREAM


def adapt_patch(is_global_patch: bool = False):
    if is_global_patch:
        from vllm_ascend.patch import platform  # noqa: F401
@ -446,20 +456,6 @@ class ProfileExecuteDuration:
            return durations


# TODO(wxy): Move to ops module
def npu_prefetch(input: torch.Tensor,
                 dependency: torch.Tensor,
                 max_size: int = 0,
                 *,
                 enabled: bool = True):
    if not enabled:
        return
    input_size = input.element_size() * input.numel()
    if max_size <= 0 or max_size > input_size:
        max_size = input_size
    torch_npu.npu_prefetch(input, dependency, max_size)


# TODO(ttanzhiqiang): rm_router_logits
# dp>1 will trigger
# In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
@ -113,6 +113,7 @@ from vllm_ascend.eplb.eplb_updator import EplbUpdator
from vllm_ascend.eplb.utils import model_register
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.sample.logits_processor import build_logitsprocs
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
@ -285,6 +286,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
            self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
        else:
            self.chunked_prefill_enabled = True
        self.weight_prefetch_method = WeightPrefetchMethod(
            self.ascend_config.weight_prefetch_config)

        if self.cache_config.cache_dtype == "auto":
            self.kv_cache_dtype = self.dtype
@ -1856,7 +1859,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                num_actual_tokens=scheduler_output.
                total_num_scheduled_tokens,
                prefetch_stream=self.prefetch_stream,
                model_instance=self.model):
                model_instance=self.model,
                weight_prefetch_method=self.weight_prefetch_method):
            self.maybe_setup_kv_connector(scheduler_output)

            hidden_states = self._generate_process_reqs_hidden_states(
@ -2370,7 +2374,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                aclgraph_runtime_mode=aclgraph_runtime_mode,
                batch_descriptor=batch_descriptor,
                prefetch_stream=self.prefetch_stream,
                model_instance=self.model):
                model_instance=self.model,
                weight_prefetch_method=self.weight_prefetch_method):
            hidden_states = self._generate_dummy_run_hidden_states(
                with_prefill, is_torchair_compile, input_ids, positions,
                attn_metadata, num_tokens, intermediate_tensors,