Files
vllm-ascend/vllm_ascend/ops/weight_prefetch.py
Ruri ff37575936 [1/N][Feat] Add weight prefetch feature for Attention layers (#3146)
### What this PR does / why we need it?

- Refactor and integrate a unified `WeightPrefetchMethod`
- Integrate prefetching of `qkv_proj.weight` and `o_proj.weight` in quantized
Attention modules
- Prefetching these weights ahead of matmul-like operators improves
performance by reducing L2 cache transfer latency

### Does this PR introduce _any_ user-facing change?

Adds a new option to `--additional-config`:
```json
{
    "weight_prefetch_config": {
        "enabled": false,
        "prefetch_ratio": {
            "attn": {
                "qkv": 1.0,
                "o": 1.0
            }
        }
    }
}
```
This feature is enabled by default and can be disabled through this
configuration.
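
For reference, a minimal offline-inference sketch (not part of this PR) that passes the same option programmatically; it assumes vLLM's `additional_config` engine argument forwards the dict to vllm-ascend, and the model name is only a placeholder:
```python
# Hypothetical usage sketch: enabling weight prefetch via additional_config.
# Assumes an Ascend NPU environment with vllm-ascend installed.
from vllm import LLM

llm = LLM(
    model="Qwen/Qwen2.5-7B-Instruct",  # placeholder model, not from this PR
    additional_config={
        "weight_prefetch_config": {
            "enabled": True,
            "prefetch_ratio": {
                "attn": {
                    "qkv": 1.0,
                    "o": 1.0,
                },
            },
        },
    },
)
outputs = llm.generate("Hello, my name is")
```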

### How was this patch tested?


- vLLM version: v0.11.0

---------

Signed-off-by: yuzhup <15705211260@163.com>
Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
Co-authored-by: yuzhup <15705211260@163.com>
2025-10-09 20:38:39 +08:00

76 lines
2.4 KiB
Python

from dataclasses import dataclass, field

import torch
import torch_npu

from vllm_ascend.ascend_config import WeightPrefetchConfig

SUPPORTED_MODULES = ["attn", "mlp", "moe"]


@dataclass
class ModuleWeightPrefetchConfig:
    """Per-module prefetch settings derived from the global config."""
    module_name: str
    enable: bool = False
    prefetch_ratio: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Keep only ratios inside [0, 1]; out-of-range values are dropped.
        self.prefetch_ratio = {
            prefix: ratio
            for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
        }
        assert self.module_name in SUPPORTED_MODULES, (
            f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
        )
        if self.module_name in SUPPORTED_MODULES:
            # Prefetching stays disabled unless at least one ratio is non-zero.
            self.enable = self.enable and any(
                ratio > 0 for ratio in self.prefetch_ratio.values())


class WeightPrefetchMethod:
    """
    Unified weight prefetch method.
    """

    def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
        self.attn = ModuleWeightPrefetchConfig(
            module_name="attn",
            enable=weight_prefetch_config.enabled,
            prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
                "attn", {}))

    def maybe_prefetch_attn_weight_preprocess(
            self, prefix: str, weight: torch.Tensor,
            start_flag: torch.Tensor) -> None:
        if not self.attn.enable:
            return
        # Prefetch only the configured fraction of the weight's byte size.
        weight_size = (weight.data.element_size() * weight.data.numel() *
                       self.attn.prefetch_ratio.get(prefix, 0))
        torch.ops.vllm.prefetch_preprocess(weight=weight,
                                           start_flag=start_flag,
                                           max_weight_size=int(weight_size))

    def maybe_prefetch_attn_weight_postprocess(
            self, stop_flag: torch.Tensor) -> None:
        if not self.attn.enable:
            return
        torch.ops.vllm.prefetch_postprocess(stop_flag)


def maybe_npu_prefetch(inputs: torch.Tensor,
                       dependency: torch.Tensor,
                       max_size: int = 0,
                       offset: int = 0,
                       *,
                       enabled: bool = True) -> None:
    if not enabled:
        return
    input_size = inputs.element_size() * inputs.numel()
    # Clamp the requested prefetch size to the tensor's actual byte size.
    if max_size <= 0 or max_size > input_size:
        max_size = input_size
    torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
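
Below is a hedged usage sketch (not part of the file above) showing how a quantized attention projection might wrap its matmul with these calls. The flag tensors, the duck-typed config stand-in, and the plain `torch.matmul` are illustrative assumptions; the real integration lives in vllm-ascend's quantized attention modules and requires the `torch.ops.vllm.prefetch_preprocess` / `prefetch_postprocess` custom ops to be registered on an Ascend NPU.
```python
# Illustrative sketch only: exercising WeightPrefetchMethod around a qkv matmul.
# The SimpleNamespace below is a duck-typed stand-in for WeightPrefetchConfig,
# which is assumed to expose only `.enabled` and `.prefetch_ratio` here.
from types import SimpleNamespace

import torch

from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod

config = SimpleNamespace(
    enabled=True,
    prefetch_ratio={"attn": {"qkv": 1.0, "o": 1.0}},
)
prefetcher = WeightPrefetchMethod(config)


def qkv_forward(hidden_states: torch.Tensor, qkv_weight: torch.Tensor,
                start_flag: torch.Tensor,
                stop_flag: torch.Tensor) -> torch.Tensor:
    # Stage the qkv weight into L2 cache before the matmul that consumes it.
    prefetcher.maybe_prefetch_attn_weight_preprocess(
        prefix="qkv", weight=qkv_weight, start_flag=start_flag)
    out = torch.matmul(hidden_states, qkv_weight.t())
    # Mark the end of the prefetch window once the matmul has been issued.
    prefetcher.maybe_prefetch_attn_weight_postprocess(stop_flag)
    return out
```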