Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 13:43:53 +08:00)
Support multistream of shared experts in FusedMoE (#997)
Contains #1111 for completeness.

### What this PR does / why we need it?
Implement multi-stream parallelism for MoE layers with shared experts, where the computation of the shared experts is overlapped with the expert token dispatch and combine. In addition, when multi-stream is enabled, the weights of the shared experts are forced to be replicated across all cards, regardless of any tensor parallelism configuration, to avoid AllReduce operations.

The expected overlapping is:

```
| shared gate_up | shared act |                            | shared down |
|    dispatch    |        routed gate_up, act, down        |   combine   |
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested on a 1x16 910 node, with a tailored 2-layer DSKv2.

---------

Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
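For orientation, here is a rough sketch of the intended schedule, distilled from the `fused_experts_with_mc2` changes in this diff. The `npu_stream_switch` and `npu_wait_tensor` helpers are the ones this PR adds in `vllm_ascend/utils.py`; `layer.dispatch`, `layer.routed_experts` and `layer.combine` are placeholders standing in for the MC2 dispatch, the grouped matmuls and the combine, so this is an illustration rather than the code itself:

```python
import torch

from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor


def moe_forward_sketch(layer, hidden_states: torch.Tensor,
                       topk_weights: torch.Tensor, topk_ids: torch.Tensor):
    # Default stream: dispatch routed tokens to their experts (MC2).
    expand_x = layer.dispatch(hidden_states, topk_ids)  # placeholder

    # Secondary stream: shared-expert gate_up + activation, overlapped
    # with the dispatch above.
    with npu_stream_switch("moe_secondary", 0):
        npu_wait_tensor(hidden_states, topk_weights)
        shared_gate_up, _ = layer.shared_experts.gate_up_proj(hidden_states)
        npu_wait_tensor(shared_gate_up, expand_x)
        shared_act = layer.shared_experts.act_fn(shared_gate_up)

    # Default stream: routed gate_up, act, down, then the combine.
    down_out = layer.routed_experts(expand_x, topk_weights)  # placeholder
    routed_out = layer.combine(down_out)  # placeholder

    # Secondary stream: shared-expert down_proj, overlapped with the combine.
    with npu_stream_switch("moe_secondary", 0):
        npu_wait_tensor(shared_act, down_out)
        shared_out, _ = layer.shared_experts.down_proj(shared_act)

    return routed_out, shared_out
```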
2 .github/workflows/vllm_ascend_test.yaml (vendored)
@@ -188,6 +188,7 @@ jobs:
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
fi

@@ -218,5 +219,6 @@ jobs:
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
fi
@@ -39,11 +39,11 @@ The details of each config option are as follows:
| Name | Type | Default | Description |
| ---- | ---- | ------- | ----------- |
| `enabled` | bool | `False` | Whether to enable torchair graph mode |
| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert |
| `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization |
| `use_cached_graph` | bool | `False` | Whether to use cached graph |
| `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache |
| `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty |
| `enable_multistream_shared_expert`| bool | `False` | Whether to enable multistream shared expert |

**ascend_scheduler_config**

@@ -64,7 +64,7 @@ A full example of additional configuration is as follows:
"use_cached_graph": true,
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": false,
"enable_multistream_shared_expert": false
"enable_multistream_moe": false
},
"ascend_scheduler_config": {
"enabled": true,
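For context, a minimal sketch of how the renamed option might be enabled from Python. It assumes the `additional_config` engine argument is forwarded to vllm-ascend as documented above; the model name and values are illustrative only:

```python
from vllm import LLM

# Sketch only: turn on torchair graph mode and the multistream
# shared-expert path documented in the table above.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            "enable_multistream_moe": True,
        },
        "ascend_scheduler_config": {
            "enabled": True,
        },
    },
)
```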
3 mypy.ini
@@ -6,6 +6,9 @@ warn_unused_configs = True
[mypy-torch_npu.*]
ignore_missing_imports = True

[mypy-torchair.*]
ignore_missing_imports = True

[mypy-transformers.*]
ignore_missing_imports = True
@@ -23,7 +23,7 @@ Run `pytest tests/test_offline_inference.py`.
import os
from unittest.mock import patch

import vllm # noqa: F401
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams

from tests.conftest import VllmRunner

@@ -95,3 +95,20 @@ def test_models_distributed_DeepSeek_dbo():
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)


def test_models_distributed_DeepSeek_W8A8():
example_prompts = [
"Hello, my name is",
]
max_tokens = 5

with VllmRunner(
snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"),
max_model_len=8192,
enforce_eager=True,
dtype="auto",
tensor_parallel_size=4,
quantization="ascend",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
@@ -58,7 +58,7 @@ def test_run_with_ascend_config():
"use_cached_graph": True,
"graph_batch_sizes": [1, 2, 4, 8],
"graph_batch_sizes_init": False,
"enable_multistream_shared_expert": True,
"enable_multistream_moe": True,
},
"ascend_scheduler_config": {
"enabled": True,

@@ -79,7 +79,7 @@ def test_run_with_ascend_config():
1, 2, 4, 8
]
assert not ascend_config.torchair_graph_config.graph_batch_sizes_init
assert ascend_config.torchair_graph_config.enable_multistream_shared_expert
assert ascend_config.torchair_graph_config.enable_multistream_moe
assert ascend_config.ascend_scheduler_config.enabled
assert ascend_config.ascend_scheduler_config.enable_chunked_prefill
assert ascend_config.expert_tensor_parallel_size == 1
@@ -54,8 +54,8 @@ class TorchairGraphConfig:
"graph_batch_sizes", [])
self.graph_batch_sizes_init = torchair_graph_config.get(
"graph_batch_sizes_init", False)
self.enable_multistream_shared_expert = torchair_graph_config.get(
"enable_multistream_shared_expert", False)
self.enable_multistream_moe = torchair_graph_config.get(
"enable_multistream_moe", False)
self.enable_view_optimize = torchair_graph_config.get(
"enable_view_optimize", True)
@@ -29,7 +29,7 @@ from typing import Any, Dict, List, Optional, Union

import torch
import torch.distributed as dist
import torch_npu
import torch_npu # noqa: F401
import vllm.envs as envs
from torch import nn
from transformers import PretrainedConfig

@@ -40,13 +40,10 @@ from vllm.distributed import (get_pp_group,
get_tp_group, tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear,
UnquantizedLinearMethod)
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope

@@ -67,6 +64,7 @@ from vllm.sequence import IntermediateTensors

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
from vllm_ascend.multistream.base import MSEventKey
from vllm_ascend.multistream.context import (
advance_step_multistream_layer_context, get_multistream_comm_context,

@@ -78,117 +76,17 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig,
make_multistream_metadata_ds)
from vllm_ascend.multistream.ms_split import compute_split_seq_index
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor

VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


class CustomDeepseekDBOMLP(nn.Module):

def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()

# NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
self.is_dynamic_quant = not isinstance(
self.gate_up_proj.quant_method,
UnquantizedLinearMethod) and isinstance(
self.gate_up_proj.quant_method.quant_method,
AscendW8A8DynamicLinearMethod)

def forward(self, x):
if self.is_dynamic_quant:
x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
x = torch_npu.npu_quant_matmul(
x,
self.gate_up_proj.weight,
self.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=x,
weight_scale=self.gate_up_proj.weight_scale_fp32,
activation_scale=dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)
x = torch_npu.npu_quant_matmul(
x,
self.down_proj.weight,
self.down_proj.weight_scale,
pertoken_scale=dynamic_scale,
output_dtype=torch.bfloat16,
)
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
x = tensor_model_parallel_all_reduce(x)
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)
return x
class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):

def _forward_ms_mlp(self, x):
current_ms_metadata = get_multistream_comm_context()
assert current_ms_metadata is not None
if self.is_dynamic_quant:
x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
x = torch_npu.npu_quant_matmul(
x,
self.gate_up_proj.weight,
self.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=x,
weight_scale=self.gate_up_proj.weight_scale_fp32,
activation_scale=dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)
x = torch_npu.npu_quant_matmul(
x,
self.down_proj.weight,
self.down_proj.weight_scale,
pertoken_scale=dynamic_scale,
output_dtype=torch.bfloat16,
)
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
current_ms_metadata.before_comm_event.record()
with torch.npu.stream(current_ms_metadata.comm_stream):
current_ms_metadata.before_comm_event.wait()
x = tensor_model_parallel_all_reduce(x)
current_ms_metadata.after_comm_event.record()
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
current_ms_metadata.before_comm_event.record()
@@ -25,7 +25,7 @@
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
# """Inference-only DeepseekV2/DeepseekV3 model."""

from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

@@ -69,12 +69,73 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2


class CustomDeepseekV2SiluAndMul(SiluAndMul):

def __init__(self,
*,
weight_scale: Optional[Callable[[], torch.Tensor]] = None):
super().__init__()
self.weight_scale = weight_scale

def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
torch.Tensor]]):
if isinstance(x, tuple):
assert self.weight_scale is not None
# For AscendW8A8DynamicLinearMethod:
# a dynamic scale is passed along with the quantized value.
quantized_x, dynamic_scale = x
return torch_npu.npu_dequant_swiglu_quant(
x=quantized_x,
weight_scale=self.weight_scale(),
activation_scale=dynamic_scale,
activate_left=True,
quant_mode=1)
else:
return super().forward_oot(x)


class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):

def __init__(
self,
input_size: int,
output_sizes: list[int],
bias: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
self.output_sizes = output_sizes
super().__init__(input_size,
sum(output_sizes),
bias=bias,
quant_config=quant_config,
prefix=prefix)

def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, loaded_shard_id: int):
# With no support for GGUF format yet.
assert not getattr(param, "is_gguf_weight", False)
assert not getattr(param, "is_gguf_weight_type", False)

assert loaded_shard_id < len(self.output_sizes)
shard_offset = sum(self.output_sizes[:loaded_shard_id])
shard_size = self.output_sizes[loaded_shard_id]
shard = param.data.narrow(param.output_dim, shard_offset, shard_size)

assert shard.size() == loaded_weight.size(), (
f"Tried to load weights of size {loaded_weight.size()}"
f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
)
shard.copy_(loaded_weight)


class CustomDeepseekV2MLP(nn.Module):

def __init__(

@@ -84,9 +145,11 @@ class CustomDeepseekV2MLP(nn.Module):
hidden_act: str,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
force_replicate: bool = False,
prefix: str = "",
) -> None:
super().__init__()
if not force_replicate:
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,

@@ -98,47 +161,52 @@ class CustomDeepseekV2MLP(nn.Module):
quant_config=quant_config,
reduce_results=reduce_results,
prefix=f"{prefix}.down_proj")
else:
self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = ReplicatedLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
self.act_fn = SiluAndMul()

# NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
self.is_dynamic_quant = not isinstance(
self.gate_up_proj.quant_method,
UnquantizedLinearMethod) and isinstance(
self.gate_up_proj.quant_method.quant_method,
AscendW8A8DynamicLinearMethod)
quant_method = self.gate_up_proj.quant_method
if isinstance(quant_method, UnquantizedLinearMethod):
self.act_fn = CustomDeepseekV2SiluAndMul()
elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
# TODO(sdmyzlp): Currently preserved as before:
# 1. The only quantization supported for silu is W8A8Dynamic
# 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
#
# Maybe one can implement a better and more general configuration
# scheme, e.g. by somehow passing around the tweaked `quant_config`
self.act_fn = CustomDeepseekV2SiluAndMul(
# Use lazy binding, for `weight_scale_fp32` is accessible
# only after `process_weights_after_loading`.
weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
# To be consumed by AscendW8A8DynamicLinearMethod.apply()
self.gate_up_proj._ascend_quant_config = {
"output_dtype": torch.int32,
"pertoken_scale": False,
"return_scale": True,
}
self.down_proj._ascend_quant_config = {
"output_dtype": torch.bfloat16,
"pertoken_scale": True,
"return_scale": False,
}
else:
raise NotImplementedError(
f"Quantization with [{type(quant_method)}] is NOT supported")

def forward(self, x):
if self.is_dynamic_quant:
x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
x = torch_npu.npu_quant_matmul(
x,
self.gate_up_proj.weight,
self.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=x,
weight_scale=self.gate_up_proj.weight_scale_fp32,
activation_scale=dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)
x = torch_npu.npu_quant_matmul(
x,
self.down_proj.weight,
self.down_proj.weight_scale,
pertoken_scale=dynamic_scale,
output_dtype=torch.bfloat16,
)
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
x = tensor_model_parallel_all_reduce(x)
return x
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(x)

@@ -169,6 +237,12 @@ class CustomDeepseekV2MoE(nn.Module):
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
# NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
self.enable_multistream_moe = \
ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2

self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
bias=False,

@@ -204,8 +278,11 @@ class CustomDeepseekV2MoE(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
reduce_results=True,
force_replicate=self.enable_multistream_moe,
prefix=f"{prefix}.shared_experts",
)
else:
self.shared_experts = None # type: ignore
CustomDeepseekV2MoE.top_k = config.num_experts_per_tok

self.dp_size = get_dp_group().world_size

@@ -216,12 +293,6 @@ class CustomDeepseekV2MoE(nn.Module):

self.params_dtype = torch.get_default_dtype()

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
# NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
self.enable_multistream_shared_expert = \
ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2

def forward(
self,
hidden_states: torch.Tensor,

@@ -240,12 +311,10 @@ class CustomDeepseekV2MoE(nn.Module):
enable_force_load_balance = False
if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp

num_tokens, hidden_size = hidden_states.shape

multistream = self.enable_multistream_shared_expert and not is_prefill

old_hidden_states = hidden_states.clone()
old_hidden_states = hidden_states
use_separated_shared_experts = (self.shared_experts is not None
and not self.enable_multistream_moe)

if self.tp_size > 1:
if (VLLM_ENABLE_MC2

@@ -262,25 +331,22 @@ class CustomDeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

kwargs = {}
if multistream:
kwargs.update({
"shared_experts": self.shared_experts,
"shared_hidden_states": old_hidden_states
})

hidden_states = self.experts(
experts_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
**kwargs)
shared_experts=(self.shared_experts
if not use_separated_shared_experts else None),
)

if multistream:
hidden_states, shared_output = hidden_states

hidden_states = hidden_states * self.routed_scaling_factor
if not isinstance(experts_hidden_states, tuple):
hidden_states = experts_hidden_states * self.routed_scaling_factor
else:
hidden_states = (
experts_hidden_states[0] * self.routed_scaling_factor +
experts_hidden_states[1])

if self.tp_size > 1:
if (VLLM_ENABLE_MC2

@@ -294,12 +360,9 @@ class CustomDeepseekV2MoE(nn.Module):
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)

if self.n_shared_experts is not None:
if not multistream:
shared_output = self.shared_experts(old_hidden_states)

if shared_output is not None:
hidden_states = hidden_states + shared_output
if use_separated_shared_experts:
hidden_states = hidden_states + self.shared_experts(
old_hidden_states)

return hidden_states.view(num_tokens, hidden_size)
@@ -16,7 +16,7 @@
# Adapted from vllm/tests/kernels/test_moe.py

import os
from typing import Callable, List, Optional
from typing import Any, Callable, List, Optional, Tuple, Union

import torch
import torch.distributed as dist

@@ -36,6 +36,7 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM

@@ -106,7 +107,8 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
return topk_ids_pad, unpad_indices


def fused_experts_with_mc2(hidden_states: torch.Tensor,
def fused_experts_with_mc2(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,

@@ -114,7 +116,8 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
**kwargs) -> torch.Tensor:
shared_experts: Optional[Any] = None
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
global_bs = 0
moe_expert_num = len(expert_map)
kwargs_mc2 = {

@@ -154,6 +157,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
0:5]

if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(hidden_states, topk_weights)
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
npu_wait_tensor(shared_gate_up, expand_x)
shared_act = shared_experts.act_fn(shared_gate_up)

w1 = w1.transpose(1, 2)

group_list = expert_token_nums.to(torch.int64)

@@ -210,7 +220,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,

hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

if shared_experts is None:
return hidden_states
else:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(shared_act, down_out_list)
shared_hidden_states, _ = shared_experts.down_proj(shared_act)
return hidden_states, shared_hidden_states


def apply_mlp(hidden_states_wrapper: List[torch.Tensor],

@@ -875,6 +891,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = False,
enable_force_load_balance: bool = False,
shared_experts: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:

@@ -924,7 +941,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
**kwargs)
shared_experts=shared_experts)
elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,

@@ -1053,9 +1070,6 @@ class AscendFusedMoE(FusedMoE):
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group

self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
# NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
self.enable_multistream_shared_expert = \
ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2

if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "

@@ -1102,8 +1116,8 @@ class AscendFusedMoE(FusedMoE):
router_logits: torch.Tensor,
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k=None,
**kwargs):
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None):
assert self.quant_method is not None

if top_k:

@@ -1132,7 +1146,7 @@ class AscendFusedMoE(FusedMoE):
hidden_states, router_logits)

# Matrix multiply.
hidden_states = self.quant_method.apply(
e_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,

@@ -1150,36 +1164,39 @@ class AscendFusedMoE(FusedMoE):
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num,
**kwargs)
shared_experts=shared_experts,
)

if self.enable_multistream_shared_expert and not is_prefill:
hidden_states, shared_output = hidden_states
if shared_experts is not None:
# Provide dummy implementation of "non-separated" shared experts.
if not isinstance(e_hidden_states, tuple):
return e_hidden_states, shared_experts(hidden_states)
else:
return e_hidden_states

if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif self.torchair_graph_enabled:
if USING_LCCL_COM: # type: ignore
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
hidden_states,
e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
e_hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
elif self.torchair_graph_enabled and not is_prefill:
hidden_states = dist._functional_collectives.reduce_scatter_tensor(
hidden_states,
e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
e_hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
else:
hidden_states = get_ep_group().combine(hidden_states)
e_hidden_states = get_ep_group().combine(e_hidden_states)

if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
e_hidden_states = tensor_model_parallel_all_reduce(e_hidden_states)

if self.enable_multistream_shared_expert and not is_prefill:
return hidden_states, shared_output
return hidden_states
return e_hidden_states

# ----------------------------------------- TBO-related --------------------------------------------
@@ -15,19 +15,19 @@
# limitations under the License.
#

from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch_npu
import torchair as tng # type: ignore
from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
from vllm.distributed import GroupCoordinator

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.utils import dispose_tensor
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
npu_wait_tensor)

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2

@@ -39,8 +39,7 @@ def apply_mlp(hidden_states: torch.Tensor,
w2_scale: torch.Tensor,
group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
**kwargs) -> torch.Tensor:
group_list_type: int = 1) -> torch.Tensor:
"""
apply MLP: gate_up_proj -> swiglu -> down_proj

@@ -74,23 +73,6 @@ def apply_mlp(hidden_states: torch.Tensor,
else:
pertoken_scale = dynamic_scale

shared_experts = kwargs.get('shared_experts', None)
if shared_experts:
shared_gate_up = kwargs.get('shared_gate_up', None)
shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
x=shared_gate_up,
weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
activation_scale=shared_dynamic_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=None,
activate_left=True,
quant_mode=1)

# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],

@@ -120,24 +102,11 @@ def apply_mlp(hidden_states: torch.Tensor,
group_list=group_list,
output_dtype=w2_scale.dtype)[0]

if shared_experts:
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_x, hidden_states)
shared_output = torch_npu.npu_quant_matmul(
shared_x,
shared_experts.down_proj.weight,
shared_experts.down_proj.weight_scale,
pertoken_scale=shared_dynamic_scale,
output_dtype=torch.bfloat16,
)
if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
shared_output = tensor_model_parallel_all_reduce(shared_output)
if shared_experts:
return hidden_states, shared_output
return hidden_states


def fused_experts_with_mc2(hidden_states: torch.Tensor,
def fused_experts_with_mc2(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,

@@ -149,7 +118,8 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
moe_all_to_all_group_name: str = "",
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
**kwargs) -> torch.Tensor:
shared_experts: Optional[Any] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if log2phy:
topk_ids = log2phy[topk_ids]
global_bs = 0

@@ -188,31 +158,17 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
}
kwargs_mc2.update(stage1_kwargs)

shared_experts = kwargs.get('shared_experts', None)
if shared_experts:
shared_hidden_states = kwargs.get('shared_hidden_states', None)
with tng.scope.npu_stream_switch('cv'):
tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
shared_hidden_states)
shared_gate_up = torch_npu.npu_quant_matmul(
shared_x,
shared_experts.gate_up_proj.weight,
shared_experts.gate_up_proj.weight_scale,
output_dtype=torch.int32,
)
kwargs.update({
"shared_gate_up": shared_gate_up,
"shared_dynamic_scale": shared_dynamic_scale,
})

output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
0:5]

if quant_mode == 0:
dynamic_scale = None
if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(hidden_states, topk_weights)
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
npu_wait_tensor(shared_gate_up[0], expand_x)
shared_act = shared_experts.act_fn(shared_gate_up)

# `expand_x` will be disposed in the `apply_mlp` function
down_out_list = apply_mlp(expand_x,

@@ -221,12 +177,7 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
w2,
w2_scale,
expert_token_nums,
dynamic_scale=dynamic_scale,
**kwargs)

multi_stream = isinstance(down_out_list, tuple)
if multi_stream:
down_out_list, shared_output = down_out_list
dynamic_scale=dynamic_scale)

# moeCombine
kwargs_mc2 = {

@@ -257,9 +208,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,

hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

if multi_stream:
return hidden_states, shared_output
if shared_experts is None:
return hidden_states
else:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(shared_act[0], down_out_list)
shared_output, _ = shared_experts.down_proj(shared_act)
return hidden_states, shared_output


# currently expert parallelism implemented with all2all

@@ -541,21 +496,33 @@ class AscendW8A8DynamicLinearMethod:
@staticmethod
def apply(
layer: torch.nn.Module,
x: torch.Tensor,
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
original_dtype = x.dtype
# use ATB quantize
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
return torch_npu.npu_quant_matmul(
quant_out,
config = getattr(layer, "_ascend_quant_config", {})
if not isinstance(x, tuple):
output_dtype = config.get("output_dtype", x.dtype)
quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
else:
assert "output_dtype" in config.keys(), (
f"DynamicLinearMethod needs explicitly specified `output_dtype`"
f"for pre-quantized input, got config [{config}]")
output_dtype = config["output_dtype"]
quantized_x, dynamic_scale = x
pertoken_scale = (dynamic_scale
if config.get("pertoken_scale", True) else None)

output = torch_npu.npu_quant_matmul(
quantized_x,
layer.weight,
layer.weight_scale,
pertoken_scale=dynamic_scale,
pertoken_scale=pertoken_scale,
bias=bias,
output_dtype=original_dtype,
output_dtype=output_dtype,
)
return ((output, dynamic_scale)
if config.get("return_scale", False) else output)

def process_weights_after_loading(self, layer):
if self.transpose_weight:

@@ -650,6 +617,7 @@ class AscendW8A8DynamicFusedMoEMethod:
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[

@@ -706,7 +674,7 @@ class AscendW8A8DynamicFusedMoEMethod:
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
**kwargs)
shared_experts=shared_experts)
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
@@ -19,17 +19,26 @@

import atexit
import math
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from threading import Lock
from typing import TYPE_CHECKING, List, Tuple

import torch
import torchair # type: ignore[import] # noqa: F401
from packaging.version import InvalidVersion, Version
from torch_npu.npu.streams import Event
from vllm.logger import logger

import vllm_ascend.envs as envs

try:
# Recent release of torchair has moved these ops to `.scope`.
from torchair.scope import npu_stream_switch as _npu_stream_switch
from torchair.scope import npu_wait_tensor as _npu_wait_tensor
except ImportError:
from torchair.ops import NpuStreamSwitch as _npu_stream_switch
from torchair.ops import npu_wait_tensor as _npu_wait_tensor

if TYPE_CHECKING:
from vllm.config import VllmConfig
else:

@@ -227,3 +236,14 @@ class ProfileExecuteDuration:
durations[tag] = observe_start.elapsed_time(observe_end)

return durations


def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True):
return _npu_stream_switch(tag, priority) if enabled else nullcontext()


def npu_wait_tensor(self: torch.Tensor,
dependency: torch.Tensor,
*,
enabled: bool = True):
return _npu_wait_tensor(self, dependency) if enabled else self