mirror of https://github.com/vllm-project/vllm-ascend.git — synced 2025-10-20 13:43:53 +08:00
[1/N][Feat] Add weight prefetch feature for Attention layers (#3146)
### What this PR does / why we need it?
- Refactor and integrate a unified `WeightPrefetchMethod`
- Integrate `qkv_proj.weight` and `o_proj.weight` prefetching in quantized Attention modules
- Prefetching these weights ahead of matmul-like operators improves performance by reducing L2 cache transfer latency

### Does this PR introduce _any_ user-facing change?
Adds a new option to `--additional-config`:

```json
{
    "weight_prefetch_config": {
        "enabled": false,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0
            }
        }
    }
}
```

The feature is disabled by default and can be enabled through this configuration.

### How was this patch tested?
- vLLM version: v0.11.0

---------

Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
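For reference, a minimal offline-inference sketch of passing this option through `additional_config` (the ratio values and model name are illustrative; the `LLM(..., additional_config=...)` entry point is the one shown in the docs hunk below):

```python
from vllm import LLM

# Illustrative values: turn attention weight prefetch on, prefetch all of
# qkv_proj.weight and roughly half of o_proj.weight.
llm = LLM(
    model="Qwen/Qwen3-8B",
    additional_config={
        "weight_prefetch_config": {
            "enabled": True,
            "prefetch_ratio": {
                "attn": {"qkv": 1.0, "o": 0.5},
            },
        },
    },
)
```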
@@ -24,24 +24,24 @@ LLM(model="Qwen/Qwen3-8B", additional_config={"config_key":"config_value"})

 The following table lists the additional configuration options available in vLLM Ascend:

 | Name | Type | Default | Description |
 |-------------------------------------|------|---------|----------------------------------------------------------------------------------------------------------------------------------------------|
 | `torchair_graph_config` | dict | `{}` | The config options for torchair graph mode |
 | `ascend_scheduler_config` | dict | `{}` | The config options for ascend scheduler |
+| `weight_prefetch_config` | dict | `{}` | The config options for weight prefetch |
 | `refresh` | bool | `false` | Whether to refresh global ascend config content. This value is usually used by rlhf or ut/e2e test cases. |
 | `expert_map_path` | str | `None` | When using expert load balancing for the MOE model, an expert map path needs to be passed in. |
-| `enable_prefetch` | bool | `False` | Whether to enable weight prefetch. |
 | `kv_cache_dtype` | str | `None` | When using the kv cache quantization method, kv cache dtype needs to be set; currently only int8 is supported. |
 | `enable_shared_expert_dp` | bool | `False` | When the shared expert runs in DP, it has better performance but consumes more memory. Currently only DeepSeek series models are supported. |
 | `lmhead_tensor_parallel_size` | int | `None` | The custom tensor parallel size of lmhead. |
 | `oproj_tensor_parallel_size` | int | `None` | The custom tensor parallel size of oproj. |
 | `multistream_overlap_shared_expert` | bool | `False` | Whether to enable multistream shared expert. This option only takes effect on MoE models with shared experts. |
 | `dynamic_eplb` | bool | `False` | Whether to enable dynamic eplb. |
 | `num_iterations_eplb_update` | int | `400` | Forward iterations after which eplb begins. |
 | `gate_eplb` | bool | `False` | Whether to enable eplb only once. |
 | `num_wait_worker_iterations` | int | `30` | The forward iterations after which the eplb worker finishes its CPU task. In our tests, the default value of 30 covers most cases. |
 | `expert_map_record_path` | str | `None` | When dynamic eplb is completed, save the current expert load heatmap to the specified path. |
 | `init_redundancy_expert` | int | `0` | Specify redundant experts during initialization. |

 The details of each config option are as follows:
@@ -71,6 +71,13 @@ The details of each config option are as follows:

 ascend_scheduler_config also supports the options from [vllm scheduler config](https://docs.vllm.ai/en/stable/api/vllm/config.html#vllm.config.SchedulerConfig). For example, you can add `enable_chunked_prefill: True` to ascend_scheduler_config as well.

+**weight_prefetch_config**
+
+| Name | Type | Default | Description |
+|------------------|------|------------------------------------|------------------------------------|
+| `enabled` | bool | `False` | Whether to enable weight prefetch. |
+| `prefetch_ratio` | dict | `{"attn": {"qkv": 1.0, "o": 1.0}}` | Prefetch ratio of each weight. |
+
 ### Example

 An example of additional configuration is as follows:
@@ -90,6 +97,15 @@ An example of additional configuration is as follows:

         "max_long_partial_prefills": 1,
         "long_prefill_token_threshold": 4096,
     },
+    "weight_prefetch_config": {
+        "enabled": True,
+        "prefetch_ratio": {
+            "attn": {
+                "qkv": 1.0,
+                "o": 1.0,
+            },
+        },
+    },
     "multistream_overlap_shared_expert": True,
     "refresh": False,
 }
@@ -495,7 +495,7 @@ class TestAscendMLAImpl(TestBase):
         mock_up_proj.assert_called_once()
         mock_npu_fused_infer_attention_score.assert_called_once()

-    @patch("vllm_ascend.attention.mla_v1.npu_prefetch")
+    @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
     def test_mla_preprocess(self, magic_npu_fetch):
         magic_npu_fetch.return_value = MagicMock()
         batch_size = 4
@@ -68,16 +68,23 @@ class TestAscendW8A8LinearMethod(TestBase):
         self.assertEqual(params['weight_scale'].shape, (10, 1))
         self.assertEqual(params['weight_offset'].shape, (10, 1))

+    @patch("vllm_ascend.quantization.w8a8.get_forward_context")
     @patch("vllm_ascend.quantization.w8a8.quant_per_tensor")
     @patch("torch_npu.npu_quant_matmul")
     def test_apply_with_x_not_int8(self, mock_npu_quant_matmul,
-                                   mock_quant_per_tensor):
+                                   mock_quant_per_tensor,
+                                   mock_get_forward_context):
         layer = MagicMock()
         layer.aclnn_input_scale = 0.1
         layer.aclnn_input_offset = 0.2
         layer.weight = torch.randn(128, 256)
         layer.deq_scale = 0.3

+        mock_forward_context = MagicMock()
+        mock_get_forward_context.return_value = mock_forward_context
+        mock_weight_prefetch_method = MagicMock()
+        mock_forward_context.weight_prefetch_method = mock_weight_prefetch_method
+
         x = torch.randn(32, 128)
         bias = torch.randn(256)
         mock_quant_per_tensor.return_value = torch.randint(-128,
@@ -45,6 +45,12 @@ class AscendConfig:
             "ascend_scheduler_config", {})
         self.ascend_scheduler_config = AscendSchedulerConfig(
             ascend_scheduler_config)

+        weight_prefetch_config = additional_config.get(
+            "weight_prefetch_config", {})
+        self.weight_prefetch_config = WeightPrefetchConfig(
+            weight_prefetch_config)
+
         # Todo: Once https://github.com/vllm-project/vllm/issues/22246 is merged in vllm. Remove this config
         self.expert_map_path = additional_config.get("expert_map_path", None)
         self.expert_map_record_path = additional_config.get(

@@ -65,7 +71,6 @@ class AscendConfig:
         ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
         self.multistream_overlap_shared_expert = additional_config.get(
             "multistream_overlap_shared_expert", False)
-        self.enable_prefetch = additional_config.get("enable_prefetch", False)
         self.lmhead_tensor_parallel_size = additional_config.get(
             "lmhead_tensor_parallel_size", None)
         if self.lmhead_tensor_parallel_size is not None:

@@ -185,6 +190,24 @@ class AscendSchedulerConfig:
             setattr(self, k, v)


+class WeightPrefetchConfig:
+    """
+    Configuration Object for weight_prefetch_config from additional_config
+    """
+
+    prefetch_ratio: dict = {
+        "attn": {
+            "qkv": 1.0,
+            "o": 1.0,
+        },
+    }
+
+    def __init__(self, weight_prefetch_config: dict):
+        self.enabled = weight_prefetch_config.get("enabled", False)
+        self.prefetch_ratio = weight_prefetch_config.get(
+            "prefetch_ratio", self.prefetch_ratio)
+
+
 _ASCEND_CONFIG: Optional[AscendConfig] = None
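A brief usage sketch of the new config object (assuming `vllm_ascend` is importable; the values are illustrative):

```python
# WeightPrefetchConfig simply reads "enabled" and "prefetch_ratio" from the
# weight_prefetch_config sub-dict of additional_config, falling back to the
# class-level defaults shown in the hunk above.
from vllm_ascend.ascend_config import WeightPrefetchConfig

cfg = WeightPrefetchConfig({
    "enabled": True,
    "prefetch_ratio": {"attn": {"qkv": 1.0, "o": 0.5}},
})
assert cfg.enabled is True
assert cfg.prefetch_ratio["attn"]["o"] == 0.5

# An empty dict keeps the defaults: disabled, full-ratio qkv/o prefetch.
default_cfg = WeightPrefetchConfig({})
assert default_cfg.enabled is False
assert default_cfg.prefetch_ratio == {"attn": {"qkv": 1.0, "o": 1.0}}
```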
@@ -1,7 +1,7 @@
 import math
 from contextlib import contextmanager
 from enum import Enum
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import torch
 from vllm.config import CUDAGraphMode, VllmConfig

@@ -13,6 +13,11 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.utils import enable_sp

+if TYPE_CHECKING:
+    from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
+else:
+    WeightPrefetchMethod = None
+

 class FusedMoEState(Enum):
     AllGather = 0

@@ -65,7 +70,8 @@ def set_ascend_forward_context(
         aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
         batch_descriptor: Optional[BatchDescriptor] = None,
         prefetch_stream: torch.npu.Stream = None,
-        model_instance: torch.nn.Module = None):
+        model_instance: torch.nn.Module = None,
+        weight_prefetch_method: Optional[WeightPrefetchMethod] = None):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
     We add some additional param into forward_context.

@@ -127,6 +133,7 @@ def set_ascend_forward_context(
                 hasattr(model_instance.model, "start_layer"):
             forward_context.layer_idx = model_instance.model.start_layer

+        # TODO(rjg-lyh): refactor mlp weight prefetch method
         # set for mlp weight prefetch
         prefetch_mlp_enabled = envs_ascend.VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE and \
             envs_ascend.VLLM_ASCEND_ENABLE_PREFETCH_MLP and \

@@ -138,6 +145,8 @@ def set_ascend_forward_context(
             forward_context.prefetch_mlp_gate_up_proj = False
             forward_context.prefetch_mlp_down_proj = False
         forward_context.prefetch_mlp_enabled = prefetch_mlp_enabled
+        # TODO(yuzhup): integrate moe weight prefetch method
+        forward_context.weight_prefetch_method = weight_prefetch_method

         # TODO(rjg-lyh): The current implementation is somewhat brute force and not elegant.
         # It will be improved later by implementing operator fusion through the FX graph.
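For orientation, a hedged, simplified sketch of how the new `weight_prefetch_method` argument is produced and stored (not the runner's literal code; see the model-runner hunk at the end of this diff for the real wiring):

```python
# Simplified flow: the NPU model runner creates one WeightPrefetchMethod from
# the ascend config and passes it into set_ascend_forward_context, which
# stores it on the forward context for layers to read back during forward.
from vllm_ascend.ascend_config import WeightPrefetchConfig
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod

weight_prefetch_method = WeightPrefetchMethod(
    WeightPrefetchConfig({"enabled": True}))

# Inside the runner, roughly:
# with set_ascend_forward_context(..., weight_prefetch_method=weight_prefetch_method):
#     ...  # during this block get_forward_context().weight_prefetch_method is set
```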
@@ -24,7 +24,7 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
-from vllm_ascend.utils import npu_prefetch
+from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch

 if TYPE_CHECKING:

@@ -493,7 +493,7 @@ class AscendMLAImpl(MLAAttentionImpl):

         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
-        self.enable_prefetch = ascend_config.enable_prefetch
+        self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz

         vllm_config = get_current_vllm_config()

@@ -877,9 +877,9 @@ class AscendMLAImpl(MLAAttentionImpl):
         num_decode_tokens = attn_metadata.num_decode_tokens
         num_actual_tokens = attn_metadata.num_actual_tokens
         if self.q_a_proj is not None:
-            npu_prefetch(self.q_a_proj.weight,
-                         hidden_states,
-                         enabled=self.enable_prefetch)
+            maybe_npu_prefetch(inputs=self.q_a_proj.weight,
+                               dependency=hidden_states,
+                               enabled=self.enable_prefetch)
             ckq = self.q_a_proj(hidden_states)[0]
             q_c = self.q_a_layernorm(ckq)
         else:

@@ -1005,10 +1005,10 @@ class AscendMLAImpl(MLAAttentionImpl):
         current_ms_metadata = get_multistream_comm_context()
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
         if current_ms_metadata is None:
-            npu_prefetch(self.o_proj.weight,
-                         o_proj_input,
-                         max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                         enabled=self.enable_prefetch)
+            maybe_npu_prefetch(inputs=self.o_proj.weight,
+                               dependency=o_proj_input,
+                               max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                               enabled=self.enable_prefetch)

             output[...] = self.o_proj(
                 o_proj_input,

@@ -1016,10 +1016,10 @@ class AscendMLAImpl(MLAAttentionImpl):
                 is_force_scatter=self.enable_shared_expert_dp)[0]
         else:
             with torch.npu.stream(current_ms_metadata.comm_stream):
-                npu_prefetch(self.o_proj.weight,
-                             o_proj_input,
-                             max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                             enabled=self.enable_prefetch)
+                maybe_npu_prefetch(inputs=self.o_proj.weight,
+                                   dependency=o_proj_input,
+                                   max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                                   enabled=self.enable_prefetch)
                 output[...] = self.o_proj(
                     o_proj_input,
                     is_prefill=prefill_preprocess_res is not None,
@@ -11,6 +11,8 @@ from vllm.utils import direct_register_custom_op

 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_forward_context import MoECommType
+from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
+from vllm_ascend.utils import npu_stream_switch, prefetch_stream


 def _maybe_chunk_residual_impl(x: torch.Tensor,

@@ -148,6 +150,33 @@ def _maybe_wait_prefetch_done_impl_fake(x: torch.Tensor) -> None:
     return


+def _prefetch_preprocess_impl(weight: torch.Tensor, start_flag: torch.Tensor,
+                              max_weight_size: int) -> None:
+    calculation_stream = torch_npu.npu.current_stream()
+    weight_prefetch_stream = prefetch_stream()
+    weight_prefetch_stream.wait_stream(calculation_stream)
+    with npu_stream_switch(weight_prefetch_stream):
+        maybe_npu_prefetch(inputs=weight,
+                           dependency=start_flag,
+                           max_size=max_weight_size)
+
+
+def _prefetch_preprocess_impl_fake(weight: torch.Tensor,
+                                   start_flag: torch.Tensor,
+                                   max_weight_size: int) -> None:
+    return
+
+
+def _prefetch_postprocess_impl(stop_flag: torch.Tensor) -> None:
+    calculation_stream = torch_npu.npu.current_stream()
+    weight_prefetch_stream = prefetch_stream()
+    calculation_stream.wait_stream(weight_prefetch_stream)
+
+
+def _prefetch_postprocess_impl_fake(stop_flag: torch.Tensor) -> None:
+    return
+
+
 def _maybe_all_reduce_tensor_model_parallel_impl(
         final_hidden_states: torch.Tensor) -> torch.Tensor:
     forward_context = get_forward_context()

@@ -194,6 +223,18 @@ direct_register_custom_op(op_name="maybe_wait_prefetch_done",
                           mutates_args=[],
                           dispatch_key="PrivateUse1")

+direct_register_custom_op(op_name="prefetch_preprocess",
+                          op_func=_prefetch_preprocess_impl,
+                          fake_impl=_prefetch_preprocess_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
+direct_register_custom_op(op_name="prefetch_postprocess",
+                          op_func=_prefetch_postprocess_impl,
+                          fake_impl=_prefetch_postprocess_impl_fake,
+                          mutates_args=[],
+                          dispatch_key="PrivateUse1")
+
 direct_register_custom_op(op_name="maybe_all_reduce_tensor_model_parallel",
                           op_func=_maybe_all_reduce_tensor_model_parallel_impl,
                           fake_impl=lambda x: x,
vllm_ascend/ops/weight_prefetch.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+from dataclasses import dataclass, field
+
+import torch
+import torch_npu
+
+from vllm_ascend.ascend_config import WeightPrefetchConfig
+
+SUPPORTED_MODULES = ["attn", "mlp", "moe"]
+
+
+@dataclass
+class ModuleWeightPrefetchConfig:
+    module_name: str
+    enable: bool = False
+    prefetch_ratio: dict = field(default_factory=dict)
+
+    def __post_init__(self) -> None:
+        self.prefetch_ratio = {
+            prefix: ratio
+            for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
+        }
+
+        assert self.module_name in SUPPORTED_MODULES, (
+            f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
+        )
+
+        if self.module_name in SUPPORTED_MODULES:
+            self.enable = self.enable and any(self.prefetch_ratio.values()) > 0
+
+
+class WeightPrefetchMethod:
+    """
+    Unified weight prefetch method.
+    """
+
+    def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
+        self.attn = ModuleWeightPrefetchConfig(
+            module_name="attn",
+            enable=weight_prefetch_config.enabled,
+            prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
+                "attn", {}))
+
+    def maybe_prefetch_attn_weight_preprocess(
+            self, prefix: str, weight: torch.Tensor,
+            start_flag: torch.Tensor) -> None:
+        if not self.attn.enable:
+            return
+
+        weight_size = weight.data.element_size() * weight.data.numel(
+        ) * self.attn.prefetch_ratio.get(prefix, 0)
+
+        torch.ops.vllm.prefetch_preprocess(weight=weight,
+                                           start_flag=start_flag,
+                                           max_weight_size=int(weight_size))
+
+    def maybe_prefetch_attn_weight_postprocess(
+            self, stop_flag: torch.Tensor) -> None:
+        if not self.attn.enable:
+            return
+
+        torch.ops.vllm.prefetch_postprocess(stop_flag)
+
+
+def maybe_npu_prefetch(inputs: torch.Tensor,
+                       dependency: torch.Tensor,
+                       max_size: int = 0,
+                       offset: int = 0,
+                       *,
+                       enabled: bool = True) -> None:
+    if not enabled:
+        return
+    input_size = inputs.element_size() * inputs.numel()
+    if max_size <= 0 or max_size > input_size:
+        max_size = input_size
+    torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
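To make the size computation in `maybe_prefetch_attn_weight_preprocess` concrete, a small illustration with made-up shapes (pure torch, no NPU needed for the arithmetic):

```python
# Assumed example: an int8 qkv_proj weight of shape (3 * 4096, 4096) with
# prefetch_ratio["qkv"] = 0.5. The byte count passed to
# torch.ops.vllm.prefetch_preprocess is element_size * numel * ratio.
import torch

weight = torch.empty(3 * 4096, 4096, dtype=torch.int8)
ratio = 0.5
max_weight_size = int(weight.element_size() * weight.numel() * ratio)
print(max_weight_size)  # 25165824 bytes (~24 MiB) requested for prefetch
```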
@@ -21,6 +21,7 @@ import torch
 import torch_npu
 from vllm.attention.backends.abstract import AttentionType
 from vllm.distributed.parallel_state import get_ep_group
+from vllm.forward_context import get_forward_context

 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.ops.moe.experts_selector import select_experts

@@ -97,11 +98,32 @@ class AscendW8A8LinearMethod:
         tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
         if x.dtype != torch.int8:
+            attn_weight_map = {
+                "AscendQKVParallelLinear": "qkv",
+                "AscendRowParallelLinear": "o",
+            }
+            layer_cls_name = layer.__class__.__name__
+            weight_prefetch_method = get_forward_context(
+            ).weight_prefetch_method
+            assert weight_prefetch_method is not None
+
+            # prefetch_qkvo_proj.weight preprocess
+            weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
+                prefix=attn_weight_map.get(layer_cls_name, ""),
+                weight=layer.weight,
+                start_flag=x,
+            )
+            # quant
             x = quant_per_tensor(
                 x,
                 layer.aclnn_input_scale_reciprocal,
                 layer.aclnn_input_offset,
             )
+            # prefetch_qkvo_proj.weight postprocess
+            if layer_cls_name in attn_weight_map.keys():
+                weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
+                    x)
+
         quant_bias = layer.quant_bias if tp_rank == 0 else None
         if is_310p():
             # On 300I Duo platform, we need transpose again if
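A hedged, self-contained sketch of the consumer path this hunk adds (the function name `prefetch_linear_weight` is illustrative; the real call site is the quantized linear method above):

```python
# Illustrative only: how a layer reaches the shared WeightPrefetchMethod during
# a forward pass and brackets its quantization/matmul with the prefetch hooks.
# Only meaningful while set_ascend_forward_context(...) is active.
import torch
from vllm.forward_context import get_forward_context


def prefetch_linear_weight(weight: torch.Tensor, x: torch.Tensor,
                           prefix: str = "qkv") -> None:
    weight_prefetch_method = get_forward_context().weight_prefetch_method
    if weight_prefetch_method is None:
        return
    # Kick off the async weight copy on the prefetch stream before the heavy op.
    weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
        prefix=prefix, weight=weight, start_flag=x)
    # ... run quantization / matmul here ...
    # Then make the compute stream wait until the prefetch stream has finished.
    weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(x)
```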
@@ -70,11 +70,12 @@ from vllm.sequence import IntermediateTensors
 from vllm_ascend import envs
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.models.layers.sfa import Indexer
+from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
 from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
     TorchairAscendW8A8DynamicLinearMethod
-from vllm_ascend.utils import dispose_tensor, npu_prefetch, oproj_tp_enable
+from vllm_ascend.utils import dispose_tensor, oproj_tp_enable


 class TorchairDeepseekV2SiluAndMul(SiluAndMul):

@@ -589,9 +590,9 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             and attn_metadata.num_decodes > 0)
         forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
         if self.q_lora_rank is not None:
-            npu_prefetch(self.q_a_proj.weight,
-                         hidden_states,
-                         enabled=enable_multistream_mla)
+            maybe_npu_prefetch(self.q_a_proj.weight,
+                               hidden_states,
+                               enabled=enable_multistream_mla)
             ckq = self.q_a_proj(hidden_states)[0]
             hidden_states_or_q_c = self.q_a_layernorm(ckq)
             forward_kwargs['ckq'] = ckq
@@ -23,9 +23,9 @@ from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
+from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.torchair.utils import (TorchairCommonAttentionMetadata,
                                         npu_stream_switch, npu_wait_tensor)
-from vllm_ascend.utils import npu_prefetch
 from vllm_ascend.worker.npu_input_batch import InputBatch

 if TYPE_CHECKING:

@@ -684,10 +684,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         if hasattr(self, "running_in_graph") and not self.running_in_graph:
             return x
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
-        npu_prefetch(self.o_proj.weight,
-                     x,
-                     max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                     enabled=enable_multistream_mla)
+        maybe_npu_prefetch(self.o_proj.weight,
+                           x,
+                           max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                           enabled=enable_multistream_mla)
         return self.o_proj(x, is_prefill=False)[0]

     # Return `ql_nope`, `q_pe`

@@ -1281,10 +1281,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
         current_ms_metadata = get_multistream_comm_context()
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024  # 16MB
         if current_ms_metadata is None:
-            npu_prefetch(self.o_proj.weight,
-                         o_proj_input,
-                         max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                         enabled=enable_multistream_mla)
+            maybe_npu_prefetch(self.o_proj.weight,
+                               o_proj_input,
+                               max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                               enabled=enable_multistream_mla)

             output[...] = self.o_proj(
                 o_proj_input,

@@ -1292,10 +1292,10 @@ class AscendMLATorchairImpl(MLAAttentionImpl):
                 is_force_scatter=self.enable_shared_expert_dp)[0]
         else:
             with torch.npu.stream(current_ms_metadata.comm_stream):
-                npu_prefetch(self.o_proj.weight,
-                             o_proj_input,
-                             max_size=MAX_O_PROJ_PREFETCH_SIZE,
-                             enabled=enable_multistream_mla)
+                maybe_npu_prefetch(self.o_proj.weight,
+                                   o_proj_input,
+                                   max_size=MAX_O_PROJ_PREFETCH_SIZE,
+                                   enabled=enable_multistream_mla)
                 output[...] = self.o_proj(
                     o_proj_input,
                     is_prefill=True,
@@ -51,6 +51,7 @@ _CUSTOM_OP_ENABLED = None
 _IS_310P = None
 _SLEEP_MODE_ENABLED = None
 _CURRENT_STREAM = None
+_PREFETCH_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False


@@ -241,6 +242,15 @@ def current_stream() -> torch.npu.Stream:
     return _CURRENT_STREAM


+def prefetch_stream() -> torch.npu.Stream:
+    global _PREFETCH_STREAM
+    if _PREFETCH_STREAM is None:
+        # when this function is called before any stream is set,
+        # we return the default stream.
+        _PREFETCH_STREAM = torch_npu.npu.Stream()
+    return _PREFETCH_STREAM
+
+
 def adapt_patch(is_global_patch: bool = False):
     if is_global_patch:
         from vllm_ascend.patch import platform  # noqa: F401

@@ -446,20 +456,6 @@ class ProfileExecuteDuration:
         return durations


-# TODO(wxy): Move to ops module
-def npu_prefetch(input: torch.Tensor,
-                 dependency: torch.Tensor,
-                 max_size: int = 0,
-                 *,
-                 enabled: bool = True):
-    if not enabled:
-        return
-    input_size = input.element_size() * input.numel()
-    if max_size <= 0 or max_size > input_size:
-        max_size = input_size
-    torch_npu.npu_prefetch(input, dependency, max_size)
-
-
 # TODO(ttanzhiqiang): rm_router_logits
 # dp>1 will trigger
 # In theory, this solution is only applicable to AllGather and AllGatherEP, because in the dp scenario, the previous operation was gate + two communications, and now it is changed to one communication + gate operation, which can save some communication time. In theory, all moe AllGather and AllGatherEP solutions can follow this logic, but now other moe models (qwen3-235b) dp solutions are not adjusted, so use the switch to control it to prevent code errors.
@@ -113,6 +113,7 @@ from vllm_ascend.eplb.eplb_updator import EplbUpdator
 from vllm_ascend.eplb.utils import model_register
 from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
+from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.sample.logits_processor import build_logitsprocs
 from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler

@@ -285,6 +286,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
         else:
             self.chunked_prefill_enabled = True
+        self.weight_prefetch_method = WeightPrefetchMethod(
+            self.ascend_config.weight_prefetch_config)

         if self.cache_config.cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype

@@ -1856,7 +1859,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 num_actual_tokens=scheduler_output.
                 total_num_scheduled_tokens,
                 prefetch_stream=self.prefetch_stream,
-                model_instance=self.model):
+                model_instance=self.model,
+                weight_prefetch_method=self.weight_prefetch_method):
             self.maybe_setup_kv_connector(scheduler_output)

             hidden_states = self._generate_process_reqs_hidden_states(

@@ -2370,7 +2374,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
                 batch_descriptor=batch_descriptor,
                 prefetch_stream=self.prefetch_stream,
-                model_instance=self.model):
+                model_instance=self.model,
+                weight_prefetch_method=self.weight_prefetch_method):
             hidden_states = self._generate_dummy_run_hidden_states(
                 with_prefill, is_torchair_compile, input_ids, positions,
                 attn_metadata, num_tokens, intermediate_tensors,