From 7bdc606677705d072c1dc45f050a3c3471d6d379 Mon Sep 17 00:00:00 2001
From: sdmyzlp <117554856+sdmyzlp@users.noreply.github.com>
Date: Wed, 11 Jun 2025 09:18:38 +0800
Subject: [PATCH] Support multistream of shared experts in FusedMoE (#997)

Contains #1111 for completeness.

### What this PR does / why we need it?
Implement multi-stream parallelism for MoE layers with shared experts,
where the computation of shared experts is overlapped with expert token
dispatch and combine. Also, when multi-stream is enabled, the weights of
shared experts are forced to be replicated across all cards, regardless of
any tensor parallelism configuration, to avoid AllReduce operations.

With the expected overlapping being:
```
| shared gate_up | shared act |                            | shared down |
| dispatch       | routed gate_up, act, down | combine |
```

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Tested on a 1x16 910 node, with a tailored 2-layer DSKv2.

---------

Signed-off-by: sdmyzlp
---
 .github/workflows/vllm_ascend_test.yaml     |   2 +
 docs/source/user_guide/additional_config.md |   4 +-
 mypy.ini                                    |   3 +
 .../test_offline_inference_distributed.py   |  19 +-
 tests/singlecard/test_ascend_config.py      |   4 +-
 vllm_ascend/ascend_config.py                |   4 +-
 vllm_ascend/models/deepseek_dbo.py          | 110 +--------
 vllm_ascend/models/deepseek_v2.py           | 217 +++++++++++-------
 vllm_ascend/ops/fused_moe.py                |  77 ++++---
 vllm_ascend/quantization/w8a8_dynamic.py    | 142 +++++-------
 vllm_ascend/utils.py                        |  22 +-
 11 files changed, 296 insertions(+), 308 deletions(-)

diff --git a/.github/workflows/vllm_ascend_test.yaml b/.github/workflows/vllm_ascend_test.yaml
index 32524eeca..b02350234 100644
--- a/.github/workflows/vllm_ascend_test.yaml
+++ b/.github/workflows/vllm_ascend_test.yaml
@@ -188,6 +188,7 @@ jobs:
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
         fi
@@ -218,5 +219,6 @@ jobs:
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_QwQ
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_topk
+          VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W8A8
           VLLM_USE_MODELSCOPE=True pytest -sv tests/multicard/ --ignore=tests/multicard/test_ilama_lora_tp2.py --ignore=tests/multicard/test_offline_inference_distributed.py
         fi

diff --git a/docs/source/user_guide/additional_config.md b/docs/source/user_guide/additional_config.md
index c6558f4b6..b769a31ab 100644
--- a/docs/source/user_guide/additional_config.md
+++ b/docs/source/user_guide/additional_config.md
@@ -39,11 +39,11 @@ The details of each config option are as follows:
 | Name | Type | Default | Description |
 | ---- | ---- | ------- | ----------- |
 | `enabled` | bool | `False` | 
Whether to enable torchair graph mode | +| `enable_multistream_moe`| bool | `False` | Whether to enable multistream shared expert | | `enable_view_optimize` | bool | `True` | Whether to enable torchair view optimization | | `use_cached_graph` | bool | `False` | Whether to use cached graph | | `graph_batch_sizes` | list[int] | `[]` | The batch size for torchair graph cache | | `graph_batch_sizes_init` | bool | `False` | Init graph batch size dynamically if `graph_batch_sizes` is empty | -| `enable_multistream_shared_expert`| bool | `False` | Whether to enable multistream shared expert | **ascend_scheduler_config** @@ -64,7 +64,7 @@ A full example of additional configuration is as follows: "use_cached_graph": true, "graph_batch_sizes": [1, 2, 4, 8], "graph_batch_sizes_init": false, - "enable_multistream_shared_expert": false + "enable_multistream_moe": false }, "ascend_scheduler_config": { "enabled": true, diff --git a/mypy.ini b/mypy.ini index 72b03de21..6fe8e6c29 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,6 +6,9 @@ warn_unused_configs = True [mypy-torch_npu.*] ignore_missing_imports = True +[mypy-torchair.*] +ignore_missing_imports = True + [mypy-transformers.*] ignore_missing_imports = True diff --git a/tests/multicard/test_offline_inference_distributed.py b/tests/multicard/test_offline_inference_distributed.py index dc02c4b97..f5ec2c872 100644 --- a/tests/multicard/test_offline_inference_distributed.py +++ b/tests/multicard/test_offline_inference_distributed.py @@ -23,7 +23,7 @@ Run `pytest tests/test_offline_inference.py`. import os from unittest.mock import patch -import vllm # noqa: F401 +from modelscope import snapshot_download # type: ignore from vllm import SamplingParams from tests.conftest import VllmRunner @@ -95,3 +95,20 @@ def test_models_distributed_DeepSeek_dbo(): distributed_executor_backend="mp", ) as vllm_model: vllm_model.generate(example_prompts, sampling_params) + + +def test_models_distributed_DeepSeek_W8A8(): + example_prompts = [ + "Hello, my name is", + ] + max_tokens = 5 + + with VllmRunner( + snapshot_download("vllm-ascend/DeepSeek-V2-Lite-W8A8"), + max_model_len=8192, + enforce_eager=True, + dtype="auto", + tensor_parallel_size=4, + quantization="ascend", + ) as vllm_model: + vllm_model.generate_greedy(example_prompts, max_tokens) diff --git a/tests/singlecard/test_ascend_config.py b/tests/singlecard/test_ascend_config.py index 484fe5f70..818745f30 100644 --- a/tests/singlecard/test_ascend_config.py +++ b/tests/singlecard/test_ascend_config.py @@ -58,7 +58,7 @@ def test_run_with_ascend_config(): "use_cached_graph": True, "graph_batch_sizes": [1, 2, 4, 8], "graph_batch_sizes_init": False, - "enable_multistream_shared_expert": True, + "enable_multistream_moe": True, }, "ascend_scheduler_config": { "enabled": True, @@ -79,7 +79,7 @@ def test_run_with_ascend_config(): 1, 2, 4, 8 ] assert not ascend_config.torchair_graph_config.graph_batch_sizes_init - assert ascend_config.torchair_graph_config.enable_multistream_shared_expert + assert ascend_config.torchair_graph_config.enable_multistream_moe assert ascend_config.ascend_scheduler_config.enabled assert ascend_config.ascend_scheduler_config.enable_chunked_prefill assert ascend_config.expert_tensor_parallel_size == 1 diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index 0c072c3a3..abb60392c 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -54,8 +54,8 @@ class TorchairGraphConfig: "graph_batch_sizes", []) self.graph_batch_sizes_init = torchair_graph_config.get( 
"graph_batch_sizes_init", False) - self.enable_multistream_shared_expert = torchair_graph_config.get( - "enable_multistream_shared_expert", False) + self.enable_multistream_moe = torchair_graph_config.get( + "enable_multistream_moe", False) self.enable_view_optimize = torchair_graph_config.get( "enable_view_optimize", True) diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index c3de6ae93..9db49cbff 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -29,7 +29,7 @@ from typing import Any, Dict, List, Optional, Union import torch import torch.distributed as dist -import torch_npu +import torch_npu # noqa: F401 import vllm.envs as envs from torch import nn from transformers import PretrainedConfig @@ -40,13 +40,10 @@ from vllm.distributed import (get_pp_group, get_tp_group, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context -from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, - MergedColumnParallelLinear, ReplicatedLinear, - RowParallelLinear, - UnquantizedLinearMethod) + RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope @@ -67,6 +64,7 @@ from vllm.sequence import IntermediateTensors import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP from vllm_ascend.multistream.base import MSEventKey from vllm_ascend.multistream.context import ( advance_step_multistream_layer_context, get_multistream_comm_context, @@ -78,117 +76,17 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig, make_multistream_metadata_ds) from vllm_ascend.multistream.ms_split import compute_split_seq_index from vllm_ascend.ops.fused_moe import AscendFusedMoE -from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 -class CustomDeepseekDBOMLP(nn.Module): - - def __init__( - self, - hidden_size: int, - intermediate_size: int, - hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True, - prefix: str = "", - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") - if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. 
" - "Only silu is supported for now.") - self.act_fn = SiluAndMul() - - # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant - self.is_dynamic_quant = not isinstance( - self.gate_up_proj.quant_method, - UnquantizedLinearMethod) and isinstance( - self.gate_up_proj.quant_method.quant_method, - AscendW8A8DynamicLinearMethod) - - def forward(self, x): - if self.is_dynamic_quant: - x, dynamic_scale = torch_npu.npu_dynamic_quant(x) - x = torch_npu.npu_quant_matmul( - x, - self.gate_up_proj.weight, - self.gate_up_proj.weight_scale, - output_dtype=torch.int32, - ) - x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( - x=x, - weight_scale=self.gate_up_proj.weight_scale_fp32, - activation_scale=dynamic_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=None, - activate_left=True, - quant_mode=1) - x = torch_npu.npu_quant_matmul( - x, - self.down_proj.weight, - self.down_proj.weight_scale, - pertoken_scale=dynamic_scale, - output_dtype=torch.bfloat16, - ) - if self.down_proj.reduce_results and self.down_proj.tp_size > 1: - x = tensor_model_parallel_all_reduce(x) - return x - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x +class CustomDeepseekDBOMLP(CustomDeepseekV2MLP): def _forward_ms_mlp(self, x): current_ms_metadata = get_multistream_comm_context() assert current_ms_metadata is not None - if self.is_dynamic_quant: - x, dynamic_scale = torch_npu.npu_dynamic_quant(x) - x = torch_npu.npu_quant_matmul( - x, - self.gate_up_proj.weight, - self.gate_up_proj.weight_scale, - output_dtype=torch.int32, - ) - x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( - x=x, - weight_scale=self.gate_up_proj.weight_scale_fp32, - activation_scale=dynamic_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=None, - activate_left=True, - quant_mode=1) - x = torch_npu.npu_quant_matmul( - x, - self.down_proj.weight, - self.down_proj.weight_scale, - pertoken_scale=dynamic_scale, - output_dtype=torch.bfloat16, - ) - if self.down_proj.reduce_results and self.down_proj.tp_size > 1: - current_ms_metadata.before_comm_event.record() - with torch.npu.stream(current_ms_metadata.comm_stream): - current_ms_metadata.before_comm_event.wait() - x = tensor_model_parallel_all_reduce(x) - current_ms_metadata.after_comm_event.record() - return x gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) current_ms_metadata.before_comm_event.record() diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 96c76338b..a83ca4751 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -25,7 +25,7 @@ # # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py # """Inference-only DeepseekV2/DeepseekV3 model.""" -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -69,12 +69,73 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 +class CustomDeepseekV2SiluAndMul(SiluAndMul): + + def __init__(self, + *, + weight_scale: 
Optional[Callable[[], torch.Tensor]] = None): + super().__init__() + self.weight_scale = weight_scale + + def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor, + torch.Tensor]]): + if isinstance(x, tuple): + assert self.weight_scale is not None + # For AscendW8A8DynamicLinearMethod: + # a dynamic scale is passed along with the quantized value. + quantized_x, dynamic_scale = x + return torch_npu.npu_dequant_swiglu_quant( + x=quantized_x, + weight_scale=self.weight_scale(), + activation_scale=dynamic_scale, + activate_left=True, + quant_mode=1) + else: + return super().forward_oot(x) + + +class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear): + + def __init__( + self, + input_size: int, + output_sizes: list[int], + bias: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + self.output_sizes = output_sizes + super().__init__(input_size, + sum(output_sizes), + bias=bias, + quant_config=quant_config, + prefix=prefix) + + def weight_loader(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, loaded_shard_id: int): + # With no support for GGUF format yet. + assert not getattr(param, "is_gguf_weight", False) + assert not getattr(param, "is_gguf_weight_type", False) + + assert loaded_shard_id < len(self.output_sizes) + shard_offset = sum(self.output_sizes[:loaded_shard_id]) + shard_size = self.output_sizes[loaded_shard_id] + shard = param.data.narrow(param.output_dim, shard_offset, shard_size) + + assert shard.size() == loaded_weight.size(), ( + f"Tried to load weights of size {loaded_weight.size()}" + f"to a parameter shard of id {loaded_shard_id} size {shard.size()}" + ) + shard.copy_(loaded_weight) + + class CustomDeepseekV2MLP(nn.Module): def __init__( @@ -84,61 +145,68 @@ class CustomDeepseekV2MLP(nn.Module): hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + force_replicate: bool = False, prefix: str = "", ) -> None: super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=f"{prefix}.down_proj") + if not force_replicate: + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=f"{prefix}.down_proj") + else: + self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = ReplicatedLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") - self.act_fn = SiluAndMul() - # NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant - self.is_dynamic_quant = not isinstance( - self.gate_up_proj.quant_method, - UnquantizedLinearMethod) and isinstance( - self.gate_up_proj.quant_method.quant_method, - AscendW8A8DynamicLinearMethod) + quant_method = self.gate_up_proj.quant_method + if isinstance(quant_method, UnquantizedLinearMethod): + self.act_fn = CustomDeepseekV2SiluAndMul() + elif (isinstance(quant_method, AscendLinearMethod) and isinstance( + quant_method.quant_method, AscendW8A8DynamicLinearMethod)): + # TODO(sdmyzlp): Currently preserved as before: + # 1. The only quantization supported for silu is W8A8Dynamic + # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16 + # + # Maybe one can implement a better and more general configuration + # scheme, e.g. by somehow passing around the tweaked `quant_config` + self.act_fn = CustomDeepseekV2SiluAndMul( + # Use lazy binding, for `weight_scale_fp32` is accessible + # only after `process_weights_after_loading`. + weight_scale=lambda: self.gate_up_proj.weight_scale_fp32) + # To be consumed by AscendW8A8DynamicLinearMethod.apply() + self.gate_up_proj._ascend_quant_config = { + "output_dtype": torch.int32, + "pertoken_scale": False, + "return_scale": True, + } + self.down_proj._ascend_quant_config = { + "output_dtype": torch.bfloat16, + "pertoken_scale": True, + "return_scale": False, + } + else: + raise NotImplementedError( + f"Quantization with [{type(quant_method)}] is NOT supported") def forward(self, x): - if self.is_dynamic_quant: - x, dynamic_scale = torch_npu.npu_dynamic_quant(x) - x = torch_npu.npu_quant_matmul( - x, - self.gate_up_proj.weight, - self.gate_up_proj.weight_scale, - output_dtype=torch.int32, - ) - x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant( - x=x, - weight_scale=self.gate_up_proj.weight_scale_fp32, - activation_scale=dynamic_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=None, - activate_left=True, - quant_mode=1) - x = torch_npu.npu_quant_matmul( - x, - self.down_proj.weight, - self.down_proj.weight_scale, - pertoken_scale=dynamic_scale, - output_dtype=torch.bfloat16, - ) - if self.down_proj.reduce_results and self.down_proj.tp_size > 1: - x = tensor_model_parallel_all_reduce(x) - return x gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) @@ -169,6 +237,12 @@ class CustomDeepseekV2MoE(nn.Module): raise ValueError(f"Unsupported activation: {config.hidden_act}. 
" "Only silu is supported for now.") + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on + self.enable_multistream_moe = \ + ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2 + self.gate = ReplicatedLinear(config.hidden_size, config.n_routed_experts, bias=False, @@ -204,8 +278,11 @@ class CustomDeepseekV2MoE(nn.Module): hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=True, + force_replicate=self.enable_multistream_moe, prefix=f"{prefix}.shared_experts", ) + else: + self.shared_experts = None # type: ignore CustomDeepseekV2MoE.top_k = config.num_experts_per_tok self.dp_size = get_dp_group().world_size @@ -216,12 +293,6 @@ class CustomDeepseekV2MoE(nn.Module): self.params_dtype = torch.get_default_dtype() - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on - self.enable_multistream_shared_expert = \ - ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2 - def forward( self, hidden_states: torch.Tensor, @@ -240,12 +311,10 @@ class CustomDeepseekV2MoE(nn.Module): enable_force_load_balance = False if hasattr(attn_metadata, 'with_prefill_across_dp'): is_prefill = is_prefill or attn_metadata.with_prefill_across_dp - num_tokens, hidden_size = hidden_states.shape - - multistream = self.enable_multistream_shared_expert and not is_prefill - - old_hidden_states = hidden_states.clone() + old_hidden_states = hidden_states + use_separated_shared_experts = (self.shared_experts is not None + and not self.enable_multistream_moe) if self.tp_size > 1: if (VLLM_ENABLE_MC2 @@ -262,25 +331,22 @@ class CustomDeepseekV2MoE(nn.Module): # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) - kwargs = {} - if multistream: - kwargs.update({ - "shared_experts": self.shared_experts, - "shared_hidden_states": old_hidden_states - }) - - hidden_states = self.experts( + experts_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits, is_prefill=is_prefill, top_k=CustomDeepseekV2MoE.top_k, enable_force_load_balance=enable_force_load_balance, - **kwargs) + shared_experts=(self.shared_experts + if not use_separated_shared_experts else None), + ) - if multistream: - hidden_states, shared_output = hidden_states - - hidden_states = hidden_states * self.routed_scaling_factor + if not isinstance(experts_hidden_states, tuple): + hidden_states = experts_hidden_states * self.routed_scaling_factor + else: + hidden_states = ( + experts_hidden_states[0] * self.routed_scaling_factor + + experts_hidden_states[1]) if self.tp_size > 1: if (VLLM_ENABLE_MC2 @@ -294,12 +360,9 @@ class CustomDeepseekV2MoE(nn.Module): else: hidden_states = tensor_model_parallel_all_reduce(hidden_states) - if self.n_shared_experts is not None: - if not multistream: - shared_output = self.shared_experts(old_hidden_states) - - if shared_output is not None: - hidden_states = hidden_states + shared_output + if use_separated_shared_experts: + hidden_states = hidden_states + self.shared_experts( + old_hidden_states) return hidden_states.view(num_tokens, hidden_size) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 25f3b05d5..d6115d35c 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -16,7 +16,7 @@ # Adapted 
from vllm/tests/kernels/test_moe.py import os -from typing import Callable, List, Optional +from typing import Any, Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -36,6 +36,7 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer +from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM @@ -106,15 +107,17 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, return topk_ids_pad, unpad_indices -def fused_experts_with_mc2(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - moe_all_to_all_group_name: Optional[str] = None, - **kwargs) -> torch.Tensor: +def fused_experts_with_mc2( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + moe_all_to_all_group_name: Optional[str] = None, + shared_experts: Optional[Any] = None +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: global_bs = 0 moe_expert_num = len(expert_map) kwargs_mc2 = { @@ -154,6 +157,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ 0:5] + if shared_experts is not None: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(hidden_states, topk_weights) + shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) + npu_wait_tensor(shared_gate_up, expand_x) + shared_act = shared_experts.act_fn(shared_gate_up) + w1 = w1.transpose(1, 2) group_list = expert_token_nums.to(torch.int64) @@ -210,7 +220,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) - return hidden_states + if shared_experts is None: + return hidden_states + else: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(shared_act, down_out_list) + shared_hidden_states, _ = shared_experts.down_proj(shared_act) + return hidden_states, shared_hidden_states def apply_mlp(hidden_states_wrapper: List[torch.Tensor], @@ -875,6 +891,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): e_score_correction_bias: Optional[torch.Tensor] = None, is_prefill: bool = False, enable_force_load_balance: bool = False, + shared_experts: Optional[Any] = None, **kwargs, ) -> torch.Tensor: @@ -924,7 +941,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): top_k=top_k, expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name, - **kwargs) + shared_experts=shared_experts) elif self.torchair_graph_enabled or get_ep_group().world_size == 1: return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -1053,9 +1070,6 @@ class AscendFusedMoE(FusedMoE): self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on - self.enable_multistream_shared_expert = \ - ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2 if self.scoring_func != "softmax" and not 
self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " @@ -1102,8 +1116,8 @@ class AscendFusedMoE(FusedMoE): router_logits: torch.Tensor, is_prefill: bool, enable_force_load_balance: bool = False, - top_k=None, - **kwargs): + top_k: Optional[int] = None, + shared_experts: Optional[Any] = None): assert self.quant_method is not None if top_k: @@ -1132,7 +1146,7 @@ class AscendFusedMoE(FusedMoE): hidden_states, router_logits) # Matrix multiply. - hidden_states = self.quant_method.apply( + e_hidden_states = self.quant_method.apply( layer=self, x=hidden_states, router_logits=router_logits, @@ -1150,36 +1164,39 @@ class AscendFusedMoE(FusedMoE): enable_force_load_balance=enable_force_load_balance, log2phy=self.log2phy, global_redundant_expert_num=self.global_redundant_expert_num, - **kwargs) + shared_experts=shared_experts, + ) - if self.enable_multistream_shared_expert and not is_prefill: - hidden_states, shared_output = hidden_states + if shared_experts is not None: + # Provide dummy implementation of "non-separated" shared experts. + if not isinstance(e_hidden_states, tuple): + return e_hidden_states, shared_experts(hidden_states) + else: + return e_hidden_states if self.dp_size > 1: if VLLM_ENABLE_MC2 and not is_prefill: ... elif self.torchair_graph_enabled: if USING_LCCL_COM: # type: ignore - hidden_states = dist._functional_collectives.reduce_scatter_tensor( - hidden_states, + e_hidden_states = dist._functional_collectives.reduce_scatter_tensor( + e_hidden_states, "sum", scatter_dim=0, group=get_dp_group().device_group) elif self.torchair_graph_enabled and not is_prefill: - hidden_states = dist._functional_collectives.reduce_scatter_tensor( - hidden_states, + e_hidden_states = dist._functional_collectives.reduce_scatter_tensor( + e_hidden_states, "sum", scatter_dim=0, group=get_dp_group().device_group) else: - hidden_states = get_ep_group().combine(hidden_states) + e_hidden_states = get_ep_group().combine(e_hidden_states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - hidden_states = tensor_model_parallel_all_reduce(hidden_states) + e_hidden_states = tensor_model_parallel_all_reduce(e_hidden_states) - if self.enable_multistream_shared_expert and not is_prefill: - return hidden_states, shared_output - return hidden_states + return e_hidden_states # ----------------------------------------- TBO-related -------------------------------------------- diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index c6e863ff7..66a0a302c 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -15,19 +15,19 @@ # limitations under the License. 
# -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Tuple, Union import torch import torch.distributed as dist import torch_npu -import torchair as tng # type: ignore -from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce +from vllm.distributed import GroupCoordinator import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import select_experts -from vllm_ascend.utils import dispose_tensor +from vllm_ascend.utils import (dispose_tensor, npu_stream_switch, + npu_wait_tensor) VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 @@ -39,8 +39,7 @@ def apply_mlp(hidden_states: torch.Tensor, w2_scale: torch.Tensor, group_list: torch.Tensor, dynamic_scale: torch.Tensor = None, - group_list_type: int = 1, - **kwargs) -> torch.Tensor: + group_list_type: int = 1) -> torch.Tensor: """ apply MLP: gate_up_proj -> swiglu -> down_proj @@ -74,23 +73,6 @@ def apply_mlp(hidden_states: torch.Tensor, else: pertoken_scale = dynamic_scale - shared_experts = kwargs.get('shared_experts', None) - if shared_experts: - shared_gate_up = kwargs.get('shared_gate_up', None) - shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None) - with tng.scope.npu_stream_switch('cv'): - tng.scope.npu_wait_tensor(shared_gate_up, hidden_states) - shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant( - x=shared_gate_up, - weight_scale=shared_experts.gate_up_proj.weight_scale_fp32, - activation_scale=shared_dynamic_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=None, - activate_left=True, - quant_mode=1) - # gmm1: gate_up_proj hidden_states = torch_npu.npu_grouped_matmul( x=[hidden_states], @@ -120,36 +102,24 @@ def apply_mlp(hidden_states: torch.Tensor, group_list=group_list, output_dtype=w2_scale.dtype)[0] - if shared_experts: - with tng.scope.npu_stream_switch('cv'): - tng.scope.npu_wait_tensor(shared_x, hidden_states) - shared_output = torch_npu.npu_quant_matmul( - shared_x, - shared_experts.down_proj.weight, - shared_experts.down_proj.weight_scale, - pertoken_scale=shared_dynamic_scale, - output_dtype=torch.bfloat16, - ) - if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1: - shared_output = tensor_model_parallel_all_reduce(shared_output) - if shared_experts: - return hidden_states, shared_output return hidden_states -def fused_experts_with_mc2(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - moe_all_to_all_group_name: str = "", - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, - **kwargs) -> torch.Tensor: +def fused_experts_with_mc2( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + expert_map: torch.Tensor = None, + moe_all_to_all_group_name: str = "", + log2phy: torch.Tensor = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, +) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: if log2phy: topk_ids = log2phy[topk_ids] global_bs = 0 @@ -188,31 +158,17 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, } kwargs_mc2.update(stage1_kwargs) - 
shared_experts = kwargs.get('shared_experts', None) - if shared_experts: - shared_hidden_states = kwargs.get('shared_hidden_states', None) - with tng.scope.npu_stream_switch('cv'): - tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states) - shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant( - shared_hidden_states) - shared_gate_up = torch_npu.npu_quant_matmul( - shared_x, - shared_experts.gate_up_proj.weight, - shared_experts.gate_up_proj.weight_scale, - output_dtype=torch.int32, - ) - kwargs.update({ - "shared_gate_up": shared_gate_up, - "shared_dynamic_scale": shared_dynamic_scale, - }) - output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[ 0:5] - if quant_mode == 0: - dynamic_scale = None + if shared_experts is not None: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(hidden_states, topk_weights) + shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) + npu_wait_tensor(shared_gate_up[0], expand_x) + shared_act = shared_experts.act_fn(shared_gate_up) # `expand_x` will be disposed in the `apply_mlp` function down_out_list = apply_mlp(expand_x, @@ -221,12 +177,7 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, w2, w2_scale, expert_token_nums, - dynamic_scale=dynamic_scale, - **kwargs) - - multi_stream = isinstance(down_out_list, tuple) - if multi_stream: - down_out_list, shared_output = down_out_list + dynamic_scale=dynamic_scale) # moeCombine kwargs_mc2 = { @@ -257,9 +208,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2) - if multi_stream: + if shared_experts is None: + return hidden_states + else: + with npu_stream_switch("moe_secondary", 0): + npu_wait_tensor(shared_act[0], down_out_list) + shared_output, _ = shared_experts.down_proj(shared_act) return hidden_states, shared_output - return hidden_states # currently expert parallelism implemented with all2all @@ -541,21 +496,33 @@ class AscendW8A8DynamicLinearMethod: @staticmethod def apply( layer: torch.nn.Module, - x: torch.Tensor, + x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], bias: Optional[torch.Tensor] = None, tp_rank: Optional[int] = 0, ) -> torch.Tensor: - original_dtype = x.dtype - # use ATB quantize - quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x) - return torch_npu.npu_quant_matmul( - quant_out, + config = getattr(layer, "_ascend_quant_config", {}) + if not isinstance(x, tuple): + output_dtype = config.get("output_dtype", x.dtype) + quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x) + else: + assert "output_dtype" in config.keys(), ( + f"DynamicLinearMethod needs explicitly specified `output_dtype`" + f"for pre-quantized input, got config [{config}]") + output_dtype = config["output_dtype"] + quantized_x, dynamic_scale = x + pertoken_scale = (dynamic_scale + if config.get("pertoken_scale", True) else None) + + output = torch_npu.npu_quant_matmul( + quantized_x, layer.weight, layer.weight_scale, - pertoken_scale=dynamic_scale, + pertoken_scale=pertoken_scale, bias=bias, - output_dtype=original_dtype, + output_dtype=output_dtype, ) + return ((output, dynamic_scale) + if config.get("return_scale", False) else output) def process_weights_after_loading(self, layer): if self.transpose_weight: @@ -650,6 +617,7 @@ class AscendW8A8DynamicFusedMoEMethod: enable_force_load_balance: bool = True, log2phy: torch.Tensor = None, 
global_redundant_expert_num: int = 0, + shared_experts: Optional[Any] = None, **kwargs, ) -> torch.Tensor: assert router_logits.shape[ @@ -706,7 +674,7 @@ class AscendW8A8DynamicFusedMoEMethod: moe_all_to_all_group_name=self.moe_all_to_all_group_name, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, - **kwargs) + shared_experts=shared_experts) elif self.torchair_graph_enabled or self.ep_group.world_size == 1: return fused_experts(hidden_states=x, w1=layer.w13_weight, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 7d4093804..f41dab4b9 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -19,17 +19,26 @@ import atexit import math -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from threading import Lock from typing import TYPE_CHECKING, List, Tuple import torch +import torchair # type: ignore[import] # noqa: F401 from packaging.version import InvalidVersion, Version from torch_npu.npu.streams import Event from vllm.logger import logger import vllm_ascend.envs as envs +try: + # Recent release of torchair has moved these ops to `.scope`. + from torchair.scope import npu_stream_switch as _npu_stream_switch + from torchair.scope import npu_wait_tensor as _npu_wait_tensor +except ImportError: + from torchair.ops import NpuStreamSwitch as _npu_stream_switch + from torchair.ops import npu_wait_tensor as _npu_wait_tensor + if TYPE_CHECKING: from vllm.config import VllmConfig else: @@ -227,3 +236,14 @@ class ProfileExecuteDuration: durations[tag] = observe_start.elapsed_time(observe_end) return durations + + +def npu_stream_switch(tag: str, priority: int, *, enabled: bool = True): + return _npu_stream_switch(tag, priority) if enabled else nullcontext() + + +def npu_wait_tensor(self: torch.Tensor, + dependency: torch.Tensor, + *, + enabled: bool = True): + return _npu_wait_tensor(self, dependency) if enabled else self
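
A short usage note on the option introduced above: the sketch below shows how `enable_multistream_moe` could be switched on from user code via vLLM's `additional_config`, mirroring the documentation change in this patch. It is a minimal illustration under assumptions, not part of the patch itself: the model, `tensor_parallel_size=4` and `quantization="ascend"` are borrowed from the new `test_models_distributed_DeepSeek_W8A8` test, and `VLLM_ENABLE_MC2=1` is assumed because multistream shared experts only take effect when MC2 is enabled.

```python
import os

from vllm import LLM, SamplingParams

# Multistream shared experts are only effective together with MC2 (see the
# `enable_multistream_moe ... and VLLM_ENABLE_MC2` check in CustomDeepseekV2MoE).
# Ideally export this in the shell before launching the process.
os.environ["VLLM_ENABLE_MC2"] = "1"

llm = LLM(
    # Model taken from the new W8A8 multicard test; any DeepSeek-V2/V3
    # checkpoint with shared experts exercises the same code path.
    model="vllm-ascend/DeepSeek-V2-Lite-W8A8",
    tensor_parallel_size=4,
    quantization="ascend",
    additional_config={
        "torchair_graph_config": {
            "enabled": True,
            # Option renamed in this patch from `enable_multistream_shared_expert`.
            "enable_multistream_moe": True,
        },
    },
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=5))
print(outputs[0].outputs[0].text)
```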