Support multistream of shared experts in FusedMoE (#997)

Contains #1111 for completeness.

### What this PR does / why we need it?
Implement multi-stream parallelism for MoE layers with shared experts, so
that the computation of the shared experts is overlapped with the expert
token dispatch and combine. In addition, when multi-stream is enabled, the
weights of the shared experts are forced to be replicated across all cards,
regardless of any tensor parallelism configuration, to avoid AllReduce
operations.

The expected overlapping is:
```
| shared gate_up | shared act |              | shared down |
|    dispatch    | routed gate_up, act, down |   combine   |
```
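
For orientation, here is a minimal sketch (not the actual implementation) of the intended scheduling, based on the `npu_stream_switch`/`npu_wait_tensor` helpers added to `vllm_ascend/utils.py` in this PR; `dispatch`, `routed_experts` and `combine` are placeholders for the MC2 dispatch, grouped-matmul and combine operations used in the real code:

```python
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def moe_forward_with_shared(hidden_states, topk_weights, shared_experts):
    # Main stream: dispatch routed tokens to their experts (MC2).
    expand_x = dispatch(hidden_states, topk_weights)

    # Secondary stream: shared-expert gate_up + activation, overlapped with
    # dispatch; it only needs the routing weights to be ready.
    with npu_stream_switch("moe_secondary", 0):
        npu_wait_tensor(hidden_states, topk_weights)
        shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
        npu_wait_tensor(shared_gate_up, expand_x)
        shared_act = shared_experts.act_fn(shared_gate_up)

    # Main stream: routed experts' gate_up / act / down, then combine.
    down_out = routed_experts(expand_x)
    routed_out = combine(down_out)

    # Secondary stream: shared-expert down projection, overlapped with combine.
    with npu_stream_switch("moe_secondary", 0):
        npu_wait_tensor(shared_act, down_out)
        shared_out, _ = shared_experts.down_proj(shared_act)

    return routed_out, shared_out
```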


### Does this PR introduce _any_ user-facing change?
No.


### How was this patch tested?
Tested on a 1x16 Ascend 910 node, with a tailored 2-layer DeepSeek-V2 model.
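
For reference, a sketch of how the new option would be switched on for such a run (the checkpoint path and parallel size below are placeholders, not the exact test setup; torchair graph mode and `VLLM_ENABLE_MC2=1` are required for the multistream path to take effect):

```python
import os

from vllm import LLM

# Hypothetical reproduction setup; adjust the checkpoint and parallel size.
os.environ["VLLM_ENABLE_MC2"] = "1"

llm = LLM(
    model="path/to/DeepSeek-V2",  # placeholder checkpoint path
    tensor_parallel_size=16,
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_moe": True,
        },
    },
)
```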

---------

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
Author: sdmyzlp
Date: 2025-06-11 09:18:38 +08:00
Committed by: GitHub
Commit: 7bdc606677 (parent: 04abfd8721)
11 changed files with 296 additions and 308 deletions

View File

@@ -188,6 +188,7 @@ jobs:
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
+VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
 fi
@@ -218,5 +219,6 @@ jobs:
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
+VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
 VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
 fi

View File

@@ -39,11 +39,11 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | Whether to enable torchair graph mode |
+| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
 | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
 | `use_cached_graph` | bool | `False` | Whether to use cached graph |
 | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
 | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
-| `enable_multistream_shared_expert`| bool | `False` | Whether to enable multistream shared expert |
 **ascend_scheduler_config**
@@ -64,7 +64,7 @@ A full example of additional configuration is as follows:
         "use_cached_graph": true,
         "graph_batch_sizes": [1, 2, 4, 8],
         "graph_batch_sizes_init": false,
-        "enable_multistream_shared_expert": false
+        "enable_multistream_moe": false
     },
     "ascend_scheduler_config": {
         "enabled": true,

View File

@@ -6,6 +6,9 @@ warn_unused_configs = True
 [mypy-torch_npu.*]
 ignore_missing_imports = True
+[mypy-torchair.*]
+ignore_missing_imports = True
 [mypy-transformers.*]
 ignore_missing_imports = True

View File

@@ -23,7 +23,7 @@ Run `pytest tests/test_offline_inference.py`.
 import os
 from unittest.mock import patch
-import vllm  # noqa: F401
+from modelscope import snapshot_download  # type: ignore
 from vllm import SamplingParams
 from tests.conftest import VllmRunner
@@ -95,3 +95,20 @@ def test_models_distributed_DeepSeek_dbo():
             distributed_executor_backend="mp",
     ) as vllm_model:
         vllm_model.generate(example_prompts, sampling_params)
+def test_models_distributed_DeepSeek_W8A8():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    max_tokens = 5
+    with VllmRunner(
+            snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
+            max_model_len=8192,
+            enforce_eager=True,
+            dtype="auto",
+            tensor_parallel_size=4,
+            quantization="ascend",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)

View File

@@ -58,7 +58,7 @@ def test_run_with_ascend_config():
             "use_cached_graph": True,
             "graph_batch_sizes": [1, 2, 4, 8],
             "graph_batch_sizes_init": False,
-            "enable_multistream_shared_expert": True,
+            "enable_multistream_moe": True,
         },
         "ascend_scheduler_config": {
             "enabled": True,
@@ -79,7 +79,7 @@ def test_run_with_ascend_config():
            1, 2, 4, 8
        ]
        assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
-       assert ascend_config.torchair_graph_config.enable_multistream_shared_expert
+       assert ascend_config.torchair_graph_config.enable_multistream_moe
        assert ascend_config.ascend_scheduler_config.enabled
        assert ascend_config.ascend_scheduler_config.enable_chunked_prefill
        assert ascend_config.expert_tensor_parallel_size == 1

View File

@@ -54,8 +54,8 @@ class TorchairGraphConfig:
             "graph_batch_sizes", [])
         self.graph_batch_sizes_init = torchair_graph_config.get(
             "graph_batch_sizes_init", False)
-        self.enable_multistream_shared_expert = torchair_graph_config.get(
-            "enable_multistream_shared_expert", False)
+        self.enable_multistream_moe = torchair_graph_config.get(
+            "enable_multistream_moe", False)
         self.enable_view_optimize = torchair_graph_config.get(
             "enable_view_optimize", True)

View File

@@ -29,7 +29,7 @@ from typing import Any, Dict, List, Optional, Union
 import torch
 import torch.distributed as dist
-import torch_npu
+import torch_npu  # noqa: F401
 import vllm.envs as envs
 from torch import nn
 from transformers import PretrainedConfig
@@ -40,13 +40,10 @@ from vllm.distributed import (get_pp_group,
                               get_tp_group, tensor_model_parallel_all_reduce)
 from vllm.distributed.parallel_state import get_dp_group
 from vllm.forward_context import get_forward_context
-from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
                                                ReplicatedLinear,
-                                               RowParallelLinear,
-                                               UnquantizedLinearMethod)
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -67,6 +64,7 @@ from vllm.sequence import IntermediateTensors
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
+from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
 from vllm_ascend.multistream.base import MSEventKey
 from vllm_ascend.multistream.context import (
     advance_step_multistream_layer_context, get_multistream_comm_context,
@@ -78,117 +76,17 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig,
                                               make_multistream_metadata_ds)
 from vllm_ascend.multistream.ms_split import compute_split_seq_index
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
-from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
 VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
-class CustomDeepseekDBOMLP(nn.Module):
+class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
-        quant_config: Optional[QuantizationConfig] = None,
-        reduce_results: bool = True,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
-        if hidden_act != "silu":
-            raise ValueError(f"Unsupported activation: {hidden_act}. "
-                             "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
-        self.is_dynamic_quant = not isinstance(
-            self.gate_up_proj.quant_method,
-            UnquantizedLinearMethod) and isinstance(
-                self.gate_up_proj.quant_method.quant_method,
-                AscendW8A8DynamicLinearMethod)
-    def forward(self, x):
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                x = tensor_model_parallel_all_reduce(x)
-            return x
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
-        return x
     def _forward_ms_mlp(self, x):
         current_ms_metadata = get_multistream_comm_context()
         assert current_ms_metadata is not None
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                current_ms_metadata.before_comm_event.record()
-                with torch.npu.stream(current_ms_metadata.comm_stream):
-                    current_ms_metadata.before_comm_event.wait()
-                    x = tensor_model_parallel_all_reduce(x)
-                    current_ms_metadata.after_comm_event.record()
-            return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         current_ms_metadata.before_comm_event.record()

View File

@@ -25,7 +25,7 @@
 # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
 # """Inference-only DeepseekV2/DeepseekV3 model."""
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
@@ -69,12 +69,73 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import AscendFusedMoE
+from vllm_ascend.quantization.quant_config import AscendLinearMethod
 from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
 from vllm_ascend.utils import dispose_tensor
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
+class CustomDeepseekV2SiluAndMul(SiluAndMul):
+    def __init__(self,
+                 *,
+                 weight_scale: Optional[Callable[[], torch.Tensor]] = None):
+        super().__init__()
+        self.weight_scale = weight_scale
+    def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
+                                                       torch.Tensor]]):
+        if isinstance(x, tuple):
+            assert self.weight_scale is not None
+            # For AscendW8A8DynamicLinearMethod:
+            # a dynamic scale is passed along with the quantized value.
+            quantized_x, dynamic_scale = x
+            return torch_npu.npu_dequant_swiglu_quant(
+                x=quantized_x,
+                weight_scale=self.weight_scale(),
+                activation_scale=dynamic_scale,
+                activate_left=True,
+                quant_mode=1)
+        else:
+            return super().forward_oot(x)
+class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: list[int],
+        bias: bool = True,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        self.output_sizes = output_sizes
+        super().__init__(input_size,
+                         sum(output_sizes),
+                         bias=bias,
+                         quant_config=quant_config,
+                         prefix=prefix)
+    def weight_loader(self, param: torch.nn.Parameter,
+                      loaded_weight: torch.Tensor, loaded_shard_id: int):
+        # With no support for GGUF format yet.
+        assert not getattr(param, "is_gguf_weight", False)
+        assert not getattr(param, "is_gguf_weight_type", False)
+        assert loaded_shard_id < len(self.output_sizes)
+        shard_offset = sum(self.output_sizes[:loaded_shard_id])
+        shard_size = self.output_sizes[loaded_shard_id]
+        shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
+        assert shard.size() == loaded_weight.size(), (
+            f"Tried to load weights of size {loaded_weight.size()}"
+            f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
+        )
+        shard.copy_(loaded_weight)
 class CustomDeepseekV2MLP(nn.Module):
     def __init__(
@@ -84,61 +145,68 @@ class CustomDeepseekV2MLP(nn.Module):
         hidden_act: str,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        force_replicate: bool = False,
         prefix: str = "",
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config,
-                                           reduce_results=reduce_results,
-                                           prefix=f"{prefix}.down_proj")
+        if not force_replicate:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = RowParallelLinear(intermediate_size,
+                                               hidden_size,
+                                               bias=False,
+                                               quant_config=quant_config,
+                                               reduce_results=reduce_results,
+                                               prefix=f"{prefix}.down_proj")
+        else:
+            self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
+                hidden_size, [intermediate_size] * 2,
+                bias=False,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj")
+            self.down_proj = ReplicatedLinear(intermediate_size,
+                                              hidden_size,
+                                              bias=False,
+                                              quant_config=quant_config,
+                                              prefix=f"{prefix}.down_proj")
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")
-        self.act_fn = SiluAndMul()
-        # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
-        self.is_dynamic_quant = not isinstance(
-            self.gate_up_proj.quant_method,
-            UnquantizedLinearMethod) and isinstance(
-                self.gate_up_proj.quant_method.quant_method,
-                AscendW8A8DynamicLinearMethod)
+        quant_method = self.gate_up_proj.quant_method
+        if isinstance(quant_method, UnquantizedLinearMethod):
+            self.act_fn = CustomDeepseekV2SiluAndMul()
+        elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
+                quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
+            # TODO(sdmyzlp): Currently preserved as before:
+            # 1. The only quantization supported for silu is W8A8Dynamic
+            # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
+            #
+            # Maybe one can implement a better and more general configuration
+            # scheme, e.g. by somehow passing around the tweaked `quant_config`
+            self.act_fn = CustomDeepseekV2SiluAndMul(
+                # Use lazy binding, for `weight_scale_fp32` is accessible
+                # only after `process_weights_after_loading`.
+                weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
+            # To be consumed by AscendW8A8DynamicLinearMethod.apply()
+            self.gate_up_proj._ascend_quant_config = {
+                "output_dtype": torch.int32,
+                "pertoken_scale": False,
+                "return_scale": True,
+            }
+            self.down_proj._ascend_quant_config = {
+                "output_dtype": torch.bfloat16,
+                "pertoken_scale": True,
+                "return_scale": False,
+            }
+        else:
+            raise NotImplementedError(
+                f"Quantization with [{type(quant_method)}] is NOT supported")
     def forward(self, x):
-        if self.is_dynamic_quant:
-            x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.gate_up_proj.weight,
-                self.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-            x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=x,
-                weight_scale=self.gate_up_proj.weight_scale_fp32,
-                activation_scale=dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
-            x = torch_npu.npu_quant_matmul(
-                x,
-                self.down_proj.weight,
-                self.down_proj.weight_scale,
-                pertoken_scale=dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
-                x = tensor_model_parallel_all_reduce(x)
-            return x
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -169,6 +237,12 @@ class CustomDeepseekV2MoE(nn.Module):
             raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                              "Only silu is supported for now.")
+        ascend_config = get_ascend_config()
+        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
+        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
+        self.enable_multistream_moe = \
+            ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.n_routed_experts,
                                      bias=False,
@@ -204,8 +278,11 @@ class CustomDeepseekV2MoE(nn.Module):
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=True,
+                force_replicate=self.enable_multistream_moe,
                 prefix=f"{prefix}.shared_experts",
             )
+        else:
+            self.shared_experts = None  # type: ignore
         CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
         self.dp_size = get_dp_group().world_size
@@ -216,12 +293,6 @@ class CustomDeepseekV2MoE(nn.Module):
         self.params_dtype = torch.get_default_dtype()
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
-        self.enable_multistream_shared_expert = \
-            ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -240,12 +311,10 @@ class CustomDeepseekV2MoE(nn.Module):
             enable_force_load_balance = False
         if hasattr(attn_metadata, 'with_prefill_across_dp'):
             is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
         num_tokens, hidden_size = hidden_states.shape
-        old_hidden_states = hidden_states
-        multistream = self.enable_multistream_shared_expert and not is_prefill
+        use_separated_shared_experts = (self.shared_experts is not None
+                                        and not self.enable_multistream_moe)
+        old_hidden_states = hidden_states.clone()
         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
@@ -262,25 +331,22 @@ class CustomDeepseekV2MoE(nn.Module):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
-        kwargs = {}
-        if multistream:
-            kwargs.update({
-                "shared_experts": self.shared_experts,
-                "shared_hidden_states": old_hidden_states
-            })
-        hidden_states = self.experts(
+        experts_hidden_states = self.experts(
             hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
             enable_force_load_balance=enable_force_load_balance,
-            **kwargs)
+            shared_experts=(self.shared_experts
+                            if not use_separated_shared_experts else None),
+        )
-        if multistream:
-            hidden_states, shared_output = hidden_states
-        hidden_states = hidden_states * self.routed_scaling_factor
+        if not isinstance(experts_hidden_states, tuple):
+            hidden_states = experts_hidden_states * self.routed_scaling_factor
+        else:
+            hidden_states = (
+                experts_hidden_states[0] * self.routed_scaling_factor +
+                experts_hidden_states[1])
         if self.tp_size > 1:
             if (VLLM_ENABLE_MC2
@@ -294,12 +360,9 @@ class CustomDeepseekV2MoE(nn.Module):
             else:
                 hidden_states = tensor_model_parallel_all_reduce(hidden_states)
-        if self.n_shared_experts is not None:
-            if not multistream:
-                shared_output = self.shared_experts(old_hidden_states)
-            if shared_output is not None:
-                hidden_states = hidden_states + shared_output
+        if use_separated_shared_experts:
+            hidden_states = hidden_states + self.shared_experts(
+                old_hidden_states)
         return hidden_states.view(num_tokens, hidden_size)

View File

@@ -16,7 +16,7 @@
 # Adapted from vllm/tests/kernels/test_moe.py
 import os
-from typing import Callable, List, Optional
+from typing import Any, Callable, List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
@@ -36,6 +36,7 @@ import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
+from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
@@ -106,15 +107,17 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
     return topk_ids_pad, unpad_indices
-def fused_experts_with_mc2(hidden_states: torch.Tensor,
-                           w1: torch.Tensor,
-                           w2: torch.Tensor,
-                           topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor,
-                           top_k: int,
-                           expert_map: torch.Tensor = None,
-                           moe_all_to_all_group_name: Optional[str] = None,
-                           **kwargs) -> torch.Tensor:
+def fused_experts_with_mc2(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    moe_all_to_all_group_name: Optional[str] = None,
+    shared_experts: Optional[Any] = None
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     global_bs = 0
     moe_expert_num = len(expert_map)
     kwargs_mc2 = {
@@ -154,6 +157,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
+    if shared_experts is not None:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(hidden_states, topk_weights)
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+            npu_wait_tensor(shared_gate_up, expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
     w1 = w1.transpose(1, 2)
     group_list = expert_token_nums.to(torch.int64)
@@ -210,7 +220,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
-    return hidden_states
+    if shared_experts is None:
+        return hidden_states
+    else:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(shared_act, down_out_list)
+            shared_hidden_states, _ = shared_experts.down_proj(shared_act)
+        return hidden_states, shared_hidden_states
 def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
@@ -875,6 +891,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
         enable_force_load_balance: bool = False,
+        shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
@@ -924,7 +941,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
-                **kwargs)
+                shared_experts=shared_experts)
         elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
@@ -1053,9 +1070,6 @@ class AscendFusedMoE(FusedMoE):
         self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
-        self.enable_multistream_shared_expert = \
-            ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
@@ -1102,8 +1116,8 @@ class AscendFusedMoE(FusedMoE):
                 router_logits: torch.Tensor,
                 is_prefill: bool,
                 enable_force_load_balance: bool = False,
-                top_k=None,
-                **kwargs):
+                top_k: Optional[int] = None,
+                shared_experts: Optional[Any] = None):
         assert self.quant_method is not None
         if top_k:
@@ -1132,7 +1146,7 @@ class AscendFusedMoE(FusedMoE):
             hidden_states, router_logits)
         # Matrix multiply.
-        hidden_states = self.quant_method.apply(
+        e_hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
@@ -1150,36 +1164,39 @@ class AscendFusedMoE(FusedMoE):
             enable_force_load_balance=enable_force_load_balance,
             log2phy=self.log2phy,
             global_redundant_expert_num=self.global_redundant_expert_num,
-            **kwargs)
+            shared_experts=shared_experts,
+        )
-        if self.enable_multistream_shared_expert and not is_prefill:
-            hidden_states, shared_output = hidden_states
+        if shared_experts is not None:
+            # Provide dummy implementation of "non-separated" shared experts.
+            if not isinstance(e_hidden_states, tuple):
+                return e_hidden_states, shared_experts(hidden_states)
+            else:
+                return e_hidden_states
         if self.dp_size > 1:
             if VLLM_ENABLE_MC2 and not is_prefill:
                 ...
             elif self.torchair_graph_enabled:
                 if USING_LCCL_COM:  # type: ignore
-                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                        hidden_states,
+                    e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        e_hidden_states,
                         "sum",
                         scatter_dim=0,
                         group=get_dp_group().device_group)
                 elif self.torchair_graph_enabled and not is_prefill:
-                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
-                        hidden_states,
+                    e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        e_hidden_states,
                         "sum",
                         scatter_dim=0,
                         group=get_dp_group().device_group)
                 else:
-                    hidden_states = get_ep_group().combine(hidden_states)
+                    e_hidden_states = get_ep_group().combine(e_hidden_states)
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
+            e_hidden_states = tensor_model_parallel_all_reduce(e_hidden_states)
-        if self.enable_multistream_shared_expert and not is_prefill:
-            return hidden_states, shared_output
-        return hidden_states
+        return e_hidden_states
 # ----------------------------------------- TBO-related --------------------------------------------

View File

@@ -15,19 +15,19 @@
 # limitations under the License.
 #
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 import torch_npu
-import torchair as tng  # type: ignore
-from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
+from vllm.distributed import GroupCoordinator
 import vllm_ascend.envs as envs_ascend
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.distributed.parallel_state import get_ep_group
 from vllm_ascend.ops.fused_moe import select_experts
-from vllm_ascend.utils import dispose_tensor
+from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
+                               npu_wait_tensor)
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
@@ -39,8 +39,7 @@ def apply_mlp(hidden_states: torch.Tensor,
               w2_scale: torch.Tensor,
               group_list: torch.Tensor,
               dynamic_scale: torch.Tensor = None,
-              group_list_type: int = 1,
-              **kwargs) -> torch.Tensor:
+              group_list_type: int = 1) -> torch.Tensor:
     """
     apply MLP: gate_up_proj -> swiglu -> down_proj
@@ -74,23 +73,6 @@ def apply_mlp(hidden_states: torch.Tensor,
     else:
         pertoken_scale = dynamic_scale
-    shared_experts = kwargs.get('shared_experts', None)
-    if shared_experts:
-        shared_gate_up = kwargs.get('shared_gate_up', None)
-        shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
-        with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
-                x=shared_gate_up,
-                weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
-                activation_scale=shared_dynamic_scale,
-                bias=None,
-                quant_scale=None,
-                quant_offset=None,
-                group_index=None,
-                activate_left=True,
-                quant_mode=1)
     # gmm1: gate_up_proj
     hidden_states = torch_npu.npu_grouped_matmul(
         x=[hidden_states],
@@ -120,36 +102,24 @@ def apply_mlp(hidden_states: torch.Tensor,
         group_list=group_list,
         output_dtype=w2_scale.dtype)[0]
-    if shared_experts:
-        with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_x, hidden_states)
-            shared_output = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.down_proj.weight,
-                shared_experts.down_proj.weight_scale,
-                pertoken_scale=shared_dynamic_scale,
-                output_dtype=torch.bfloat16,
-            )
-            if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
-                shared_output = tensor_model_parallel_all_reduce(shared_output)
-    if shared_experts:
-        return hidden_states, shared_output
     return hidden_states
-def fused_experts_with_mc2(hidden_states: torch.Tensor,
-                           w1: torch.Tensor,
-                           w2: torch.Tensor,
-                           w1_scale: torch.Tensor,
-                           w2_scale: torch.Tensor,
-                           topk_weights: torch.Tensor,
-                           topk_ids: torch.Tensor,
-                           top_k: int,
-                           expert_map: torch.Tensor = None,
-                           moe_all_to_all_group_name: str = "",
-                           log2phy: torch.Tensor = None,
-                           global_redundant_expert_num: int = 0,
-                           **kwargs) -> torch.Tensor:
+def fused_experts_with_mc2(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    expert_map: torch.Tensor = None,
+    moe_all_to_all_group_name: str = "",
+    log2phy: torch.Tensor = None,
+    global_redundant_expert_num: int = 0,
+    shared_experts: Optional[Any] = None,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if log2phy:
         topk_ids = log2phy[topk_ids]
     global_bs = 0
@@ -188,31 +158,17 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     }
     kwargs_mc2.update(stage1_kwargs)
-    shared_experts = kwargs.get('shared_experts', None)
-    if shared_experts:
-        shared_hidden_states = kwargs.get('shared_hidden_states', None)
-        with tng.scope.npu_stream_switch('cv'):
-            tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
-            shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
-                shared_hidden_states)
-            shared_gate_up = torch_npu.npu_quant_matmul(
-                shared_x,
-                shared_experts.gate_up_proj.weight,
-                shared_experts.gate_up_proj.weight_scale,
-                output_dtype=torch.int32,
-            )
-        kwargs.update({
-            "shared_gate_up": shared_gate_up,
-            "shared_dynamic_scale": shared_dynamic_scale,
-        })
     output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
     # comm_stream.wait_stream(torch.npu.current_stream())
     expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
         0:5]
-    if quant_mode == 0:
-        dynamic_scale = None
+    if shared_experts is not None:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(hidden_states, topk_weights)
+            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
+            npu_wait_tensor(shared_gate_up[0], expand_x)
+            shared_act = shared_experts.act_fn(shared_gate_up)
     # `expand_x` will be disposed in the `apply_mlp` function
     down_out_list = apply_mlp(expand_x,
@@ -221,12 +177,7 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
                               w2,
                               w2_scale,
                               expert_token_nums,
-                              dynamic_scale=dynamic_scale,
-                              **kwargs)
-    multi_stream = isinstance(down_out_list, tuple)
-    if multi_stream:
-        down_out_list, shared_output = down_out_list
+                              dynamic_scale=dynamic_scale)
     # moeCombine
     kwargs_mc2 = {
@@ -257,9 +208,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
     hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
-    if multi_stream:
+    if shared_experts is None:
+        return hidden_states
+    else:
+        with npu_stream_switch("moe_secondary", 0):
+            npu_wait_tensor(shared_act[0], down_out_list)
+            shared_output, _ = shared_experts.down_proj(shared_act)
         return hidden_states, shared_output
-    return hidden_states
 # currently expert parallelism implemented with all2all
@@ -541,21 +496,33 @@ class AscendW8A8DynamicLinearMethod:
     @staticmethod
     def apply(
         layer: torch.nn.Module,
-        x: torch.Tensor,
+        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
     ) -> torch.Tensor:
-        original_dtype = x.dtype
-        # use ATB quantize
-        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
-        return torch_npu.npu_quant_matmul(
-            quant_out,
+        config = getattr(layer, "_ascend_quant_config", {})
+        if not isinstance(x, tuple):
+            output_dtype = config.get("output_dtype", x.dtype)
+            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+        else:
+            assert "output_dtype" in config.keys(), (
+                f"DynamicLinearMethod needs explicitly specified `output_dtype`"
+                f"for pre-quantized input, got config [{config}]")
+            output_dtype = config["output_dtype"]
+            quantized_x, dynamic_scale = x
+        pertoken_scale = (dynamic_scale
+                          if config.get("pertoken_scale", True) else None)
+        output = torch_npu.npu_quant_matmul(
+            quantized_x,
             layer.weight,
             layer.weight_scale,
-            pertoken_scale=dynamic_scale,
+            pertoken_scale=pertoken_scale,
            bias=bias,
-            output_dtype=original_dtype,
+            output_dtype=output_dtype,
         )
+        return ((output, dynamic_scale)
+                if config.get("return_scale", False) else output)
     def process_weights_after_loading(self, layer):
         if self.transpose_weight:
@@ -650,6 +617,7 @@ class AscendW8A8DynamicFusedMoEMethod:
         enable_force_load_balance: bool = True,
         log2phy: torch.Tensor = None,
         global_redundant_expert_num: int = 0,
+        shared_experts: Optional[Any] = None,
         **kwargs,
     ) -> torch.Tensor:
         assert router_logits.shape[
@@ -706,7 +674,7 @@ class AscendW8A8DynamicFusedMoEMethod:
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                 log2phy=log2phy,
                 global_redundant_expert_num=global_redundant_expert_num,
-                **kwargs)
+                shared_experts=shared_experts)
         elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,

View File

@@ -19,17 +19,26 @@
 import atexit
 import math
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from threading import Lock
 from typing import TYPE_CHECKING, List, Tuple
 import torch
+import torchair  # type: ignore[import]  # noqa: F401
 from packaging.version import InvalidVersion, Version
 from torch_npu.npu.streams import Event
 from vllm.logger import logger
 import vllm_ascend.envs as envs
+try:
+    # Recent release of torchair has moved these ops to `.scope`.
+    from torchair.scope import npu_stream_switch as _npu_stream_switch
+    from torchair.scope import npu_wait_tensor as _npu_wait_tensor
+except ImportError:
+    from torchair.ops import NpuStreamSwitch as _npu_stream_switch
+    from torchair.ops import npu_wait_tensor as _npu_wait_tensor
 if TYPE_CHECKING:
     from vllm.config import VllmConfig
 else:
@@ -227,3 +236,14 @@ class ProfileExecuteDuration:
                 durations[tag] = observe_start.elapsed_time(observe_end)
             return durations
+def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
+    return _npu_stream_switch(tag, priority) if enabled else nullcontext()
+def npu_wait_tensor(self: torch.Tensor,
+                    dependency: torch.Tensor,
+                    *,
+                    enabled: bool = True):
+    return _npu_wait_tensor(self, dependency) if enabled else self