diff --git a/tests/e2e/multicard/test_offline_inference_distributed.py b/tests/e2e/multicard/test_offline_inference_distributed.py
index f3348d8d5..be62e1b6c 100644
--- a/tests/e2e/multicard/test_offline_inference_distributed.py
+++ b/tests/e2e/multicard/test_offline_inference_distributed.py
@@ -166,7 +166,7 @@ def test_sp_for_qwen3_moe() -> None:
 @pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
-@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
 def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
     example_prompts = [
         "Hello, my name is",
diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 25f543e90..d6dd09161 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -500,9 +500,12 @@ class TestAscendMLAImpl(TestBase):
         mock_up_proj.assert_called_once()
         mock_npu_fused_infer_attention_score.assert_called_once()
 
+    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
     @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
-    def test_mla_preprocess(self, magic_npu_fetch):
+    def test_mla_preprocess(self, magic_npu_fetch,
+                            mock_maybe_all_gather_and_maybe_unpad):
         magic_npu_fetch.return_value = MagicMock()
+        mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
         batch_size = 4
         seq_len = 8
         hidden_size = 1024
diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py
index e33a39a13..cf97c6216 100644
--- a/tests/ut/models/test_deepseek_v2.py
+++ b/tests/ut/models/test_deepseek_v2.py
@@ -42,9 +42,11 @@ def test_row_parallel_linear(cls, mock_distributed):
     assert output[0].shape == (2, 4, 64)
 
 
+@patch("vllm_ascend.models.layers.mla.get_forward_context")
 @patch("torch.ops.vllm.mla_forward")
 @patch("torch_npu.npu_rms_norm")
 def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
+                                          mock_forward_context,
                                           mock_distributed, base_config):
     mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
     # Make a fake ascend config because of the AscendLinearBase
@@ -54,6 +56,9 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
     vllm_config.parallel_config.tensor_parallel_size = 1
     vllm_config.kv_transfer_config = None
     ascend_config.init_ascend_config(vllm_config)
+    dummy_forward_context = MagicMock()
+    dummy_forward_context.sp_enabled = False
+    mock_forward_context.return_value = dummy_forward_context
 
     attn = CustomDeepseekV2MLAAttention(config=base_config,
                                         hidden_size=128,
diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index ade8268a0..d3402fea1 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -11,7 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
                                   set_forward_context)
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.utils import enable_sp
+from vllm_ascend.utils import enable_sp, is_moe_model
 
 if TYPE_CHECKING:
     from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
@@ -112,15 +112,20 @@ def set_ascend_forward_context(
         # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
         # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
         # the performance may degrade due to the switching of communication methods.
-        sp_enabled = enable_sp(vllm_config) and \
-            tp_world_size > 1 and \
-            num_tokens is not None and num_tokens > 1000
+        if is_moe_model(vllm_config):
+            sp_enabled = enable_sp(vllm_config) and \
+                tp_world_size > 1
+        else:
+            sp_enabled = enable_sp(vllm_config) and \
+                tp_world_size > 1 and \
+                num_tokens is not None and num_tokens > 1000
 
         if sp_enabled:
             pad_size = (tp_world_size -
                         (num_tokens % tp_world_size)) % tp_world_size
             forward_context.pad_size = pad_size
         forward_context.sp_enabled = sp_enabled
+        forward_context.num_tokens = num_tokens
 
         # set this for rope forward_oot using
         forward_context.is_first_layer = True
@@ -169,8 +174,14 @@ def set_ascend_forward_context(
             dp_world_size = get_dp_group().world_size
             if dp_world_size > 1 and forward_context.dp_metadata is not None:
-                max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
-                )
+                max_tokens_across_dp = \
+                    forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
+                if sp_enabled:
+                    padded_length = (max_tokens_across_dp + tp_world_size -
+                                     1) // tp_world_size * tp_world_size
+                    pad_size = padded_length - num_tokens
+                    forward_context.padded_length = padded_length
+                    forward_context.pad_size = pad_size
             else:
                 max_tokens_across_dp = num_tokens
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 819edcbb9..3fdc7d071 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -9,7 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
@@ -1128,10 +1128,11 @@ class AscendMLAImpl(MLAAttentionImpl):
         q_c = hidden_states
         kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
-        # Process for shared_expert_dp
-        if need_gather_q_kv:
-            q_c = get_tp_group().all_gather(q_c, 0)
-            kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
+        # Process for Flash Comm V1
+        q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            q_c, need_gather_q_kv)
+        kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            kv_no_split, need_gather_q_kv)
         decode_preprocess_res = None
         prefill_preprocess_res = None
         if has_prefill:
@@ -1200,8 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
             num_decode_tokens = attn_metadata.num_decode_tokens
         # Inputs and outputs may be padded for CUDA graphs
         output_padded = output
-        output = output[:num_actual_tokens, ...]
-        o_proj_input_shape = (num_actual_tokens,
+        o_proj_input_shape = (get_forward_context().num_tokens,
                               self.num_heads * self.v_head_dim)
         o_proj_input = torch.empty(o_proj_input_shape,
                                    dtype=hidden_states.dtype,
@@ -1248,7 +1248,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                 o_proj_input[num_decode_tokens:] = output_prefill
                 current_ms_metadata.after_comm_event.record()
             else:
-                o_proj_input[num_decode_tokens:] = output_prefill
+                o_proj_input[
+                    num_decode_tokens:num_actual_tokens] = output_prefill
         # O proj
         current_ms_metadata = get_multistream_comm_context()
         MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
@@ -1258,20 +1259,14 @@ class AscendMLAImpl(MLAAttentionImpl):
                                max_size=MAX_O_PROJ_PREFETCH_SIZE,
                                enabled=self.enable_prefetch)
-            output[...] = self.o_proj(
-                o_proj_input,
-                is_prefill=prefill_preprocess_res is not None,
-                is_force_scatter=self.enable_shared_expert_dp)[0]
+            output[...] = self.o_proj(o_proj_input)[0]
         else:
             with torch.npu.stream(current_ms_metadata.comm_stream):
                 maybe_npu_prefetch(inputs=self.o_proj.weight,
                                    dependency=o_proj_input,
                                    max_size=MAX_O_PROJ_PREFETCH_SIZE,
                                    enabled=self.enable_prefetch)
-                output[...] = self.o_proj(
-                    o_proj_input,
-                    is_prefill=prefill_preprocess_res is not None,
-                    is_force_scatter=self.enable_shared_expert_dp)[0]
+                output[...] = self.o_proj(o_proj_input)[0]
                 current_ms_metadata.after_comm_event.record()
         del o_proj_input
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 2a0049675..db149acb2 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -133,8 +133,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
     # Whether to enable FlashComm optimization when tensor parallel is enabled.
    # This feature will get better performance when concurrency is large.
-    "VLLM_ASCEND_ENABLE_FLASHCOMM":
-    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
+    "VLLM_ASCEND_ENABLE_FLASHCOMM1":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
     # Whether to enable MLP weight prefetch, only used in small concurrency.
     "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index bc6152231..573f48e7d 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -300,12 +300,11 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
                                           bias=False,
                                           quant_config=quant_config,
                                           prefix=f"{prefix}.kv_b_proj")
-        self.o_proj = CustomDeepseekV2RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj")
+        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
+                                        self.hidden_size,
+                                        bias=False,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")
 
         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'
diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py
index e0c7e2dbf..752d1d1b9 100644
--- a/vllm_ascend/models/layers/mla.py
+++ b/vllm_ascend/models/layers/mla.py
@@ -122,19 +122,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
             hidden_states: torch.Tensor,
             kv_cache: Optional[torch.Tensor] = None,
             attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        num_tokens = hidden_states.shape[0]
-        need_gather_q_kv = False
-        if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
-            # Simulate all gather to calculate output shape
-            num_tokens = num_tokens * self.tp_size
-            need_gather_q_kv = True
-        if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
-            output_shape = hidden_states.shape
-        else:
-            rows = num_tokens // self.tp_size
-            if num_tokens % self.tp_size:
-                rows += 1
-            output_shape = (rows, hidden_states.shape[1])
+        need_gather_q_kv = get_forward_context().sp_enabled
+        output_shape = hidden_states.shape
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,
diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py
index c1a32c660..78183acdc 100644
--- a/vllm_ascend/ops/common_fused_moe.py
+++ b/vllm_ascend/ops/common_fused_moe.py
@@ -38,8 +38,9 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz,
-                               npu_stream_switch)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
+                               is_enable_nz, npu_stream_switch,
+                               shared_expert_dp_enabled)
 
 
 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@@ -417,6 +418,10 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
         if self.multistream_overlap_shared_expert:
             self.shared_expert_stream = torch.npu.Stream()
+        if enable_sp():
+            logger.info_once(
+                "Sequence parallelism is enabled, shared experts are replicated for best performance."
+            )
 
     def forward(
         self,
@@ -444,7 +449,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
             # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
             forward_context = get_forward_context()
             moe_comm_type = forward_context.moe_comm_type
-            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
+            if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                    and not shared_expert_dp_enabled():
                 shared_out = tensor_model_parallel_all_reduce(shared_out)
         fused_output = AscendFusedMoE.forward_impl(
             self,
diff --git a/vllm_ascend/ops/linear_op.py b/vllm_ascend/ops/linear_op.py
index 663d28eff..58a96fdec 100644
--- a/vllm_ascend/ops/linear_op.py
+++ b/vllm_ascend/ops/linear_op.py
@@ -49,7 +49,7 @@ from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
 from vllm_ascend.utils import (dense_optim_enable, enable_sp,
                                matmul_allreduce_enable, mlp_tp_enable,
-                               oproj_tp_enable)
+                               oproj_tp_enable, shared_expert_dp_enabled)
 
 
 class CustomLinearOp:
@@ -418,7 +418,8 @@ def _get_row_parallel_op(
 
 
 def get_parallel_op(disable_tp, prefix, layer, direct):
-    if disable_tp:
+    if disable_tp or ("shared_experts" in prefix
+                      and shared_expert_dp_enabled()):
         return None, 0, 1
     custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
                               MLPRowParallelOp, OProjRowParallelOp,
diff --git a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py
index 415c39637..7533ccebe 100644
--- a/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py
+++ b/vllm_ascend/ops/moe/fused_moe_prepare_and_finalize.py
@@ -27,7 +27,7 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
-from vllm_ascend.utils import get_rm_router_logits_state
+from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
 
 
 class FusedMoEPrepareAndFinalize(ABC):
@@ -198,7 +198,7 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
 
 class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
     """
    MoE communication strategy using MC2, which is based on All2All. Hence, it inherits
-    All2All and share the same finalize method.
+    All2All and shares the same finalize method. Designed for Ascend or environments requiring explicit padding and slicing control. Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
     """
@@ -277,9 +277,24 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
 
 class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
     """
-    MoE communication strategy using All-Gather + Reduce-Scatter.
-    Designed for DP > 1: gather inputs across DP ranks before MoE, scatter outputs after.
-    Uses `max_tokens_across_dp` from forward_context for padding alignment.
+    MoE communication strategy using All-Gather + Reduce-Scatter on the EP group.
+    There are two sets of prepare and finalize:
+    1. _prepare_with_dp_group/_finalize_with_dp_group: When sequence parallelism is not enabled,
+       we gather inputs across DP ranks before MoE and scatter outputs after.
+       The communication and calculation process is as follows (AG, AR and RS
+       are abbreviations for All-Gather, All-Reduce and Reduce-Scatter, respectively):
+
+       Attn → TP AR → DP AG → MoE → DP RS → TP AR
+
+    2. _prepare_with_ep_group/_finalize_with_ep_group: When sequence parallelism is enabled,
+       the above process becomes:
+
+       TP AG → Attn → TP RS → TP AG → DP AG → MoE → DP RS → TP RS
+
+    This strategy further combines TP AG + DP AG into an EP All-Gather and TP RS + DP RS
+    into an EP Reduce-Scatter to improve communication performance. The optimized process is as follows:
+
+       TP AG → Attn → TP RS → EP AG → MoE → EP RS
     """
 
     def prepare(
@@ -289,6 +304,42 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
         enable_shared_expert_dp: bool = False,
         replace_allreduce: bool = False,
         gate=None
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
+        """
+        Preparation steps:
+        AllGather hidden_states and router_logits to form global tensors.
+
+        Returns:
+            Tuple of (global_hidden_states, global_router_logits, None, None)
+        """
+        if enable_sp():
+            return self._prepare_with_ep_group(hidden_states, router_logits)
+
+        return self._prepare_with_dp_group(hidden_states, router_logits,
+                                           enable_shared_expert_dp,
+                                           replace_allreduce, gate)
+
+    def _prepare_with_ep_group(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
+        hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            hidden_states, True, True)
+        router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            router_logits, True, True)
+
+        return hidden_states, router_logits, None, None
+
+    def _prepare_with_dp_group(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        enable_shared_expert_dp: bool = False,
+        replace_allreduce: bool = False,
+        gate=None
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         """
@@ -301,7 +352,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
             Tuple of (global_hidden_states, global_router_logits, None, None)
         """
         self.enable_shared_expert_dp = enable_shared_expert_dp
-
         if self.moe_config.dp_size > 1:
             forward_context = get_forward_context()
             max_tokens_across_dp = forward_context.max_tokens_across_dp
@@ -323,7 +373,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
             else:
                 router_logits = self.moe_config.dp_group.all_gather(
                     router_logits, 0)
-
         return hidden_states, router_logits, None, None
 
     def finalize(self,
@@ -331,6 +380,36 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
                  reduce_results: bool,
                  context_metadata: Optional[dict] = None) -> torch.Tensor:
         """
+        Finalization steps:
+        Reduce-Scatter the hidden states.
+
+        Returns:
+            Tensor with shape [local_num_tokens, hidden_size]
+        """
+        if enable_sp():
+            return self._finalize_with_ep_group(hidden_states)
+
+        return self._finalize_with_dp_group(hidden_states, reduce_results)
+
+    def _finalize_with_ep_group(self,
+                                hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        The `reduce_results` argument is not needed here. With sequence parallelism enabled:
+        1. reduce_results is False usually when the model has shared experts and the hidden states
+           need an all-reduce after the shared-expert and routed-expert results are added in FusedMoE.
+           We reduce-scatter the hidden states here, skip the all-reduce in FusedMoE, and add the
+           result to the shared-expert output.
+        2. reduce_results is True usually when the model has no shared experts. We still reduce-scatter
+           here and skip the all-reduce in FusedMoE.
+        """
+        hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+            hidden_states, True)
+
+        return hidden_states
+
+    def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
+                                reduce_results: bool) -> torch.Tensor:
+        """
         Finalization steps:
         1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
         2. Slice to original local token count.
diff --git a/vllm_ascend/ops/register_custom_ops.py b/vllm_ascend/ops/register_custom_ops.py
index a66fcd33a..5e2bbcab8 100644
--- a/vllm_ascend/ops/register_custom_ops.py
+++ b/vllm_ascend/ops/register_custom_ops.py
@@ -1,7 +1,9 @@
 import torch
 import torch.nn.functional as F
 import torch_npu
-from vllm.distributed import (tensor_model_parallel_all_gather,
+from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce,
                               tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
@@ -13,8 +15,10 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.utils import npu_stream_switch, prefetch_stream
 
 
-def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
-                                           label: bool) -> torch.Tensor:
+def _maybe_all_gather_and_maybe_unpad_impl(
+        x: torch.Tensor,
+        label: bool,
+        is_ep_comm: bool = False) -> torch.Tensor:
     try:
         forward_context = get_forward_context()
     except AssertionError:
@@ -22,27 +26,66 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
     sp_enabled = forward_context.sp_enabled
     if sp_enabled and label:
-        x = tensor_model_parallel_all_gather(x, 0)
-        pad_size = forward_context.pad_size
-        if pad_size > 0:
-            x = x[:-pad_size, :]
+        dp_metadata = forward_context.dp_metadata
+        if dp_metadata is None or not is_ep_comm:
+            x = tensor_model_parallel_all_gather(x, 0)
+            pad_size = forward_context.pad_size
+            if pad_size > 0:
+                x = x[:-pad_size, :]
+        else:
+            x = get_ep_group().all_gather(x, 0)
+            # unpad
+            num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
+            result = torch.empty(
+                (num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
+                device=x.device,
+                dtype=x.dtype)
+            dp_size = get_dp_group().world_size
+            x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
+            offset = 0
+            for idx in range(dp_size):
+                num_tokens_dp = num_tokens_across_dp_cpu[idx]
+                result[offset:offset +
+                       num_tokens_dp, :] = x[idx, :num_tokens_dp, :]
+                offset += num_tokens_dp
+            x = result
+
     return x
 
 
-def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
+def _maybe_pad_and_reduce_impl(x: torch.Tensor,
+                               is_ep_comm: bool = False) -> torch.Tensor:
     try:
         forward_context = get_forward_context()
     except AssertionError:
         return tensor_model_parallel_all_reduce(x)
 
-    sp_enabled = forward_context.sp_enabled
-    if sp_enabled:
+    if not forward_context.sp_enabled:
+        return tensor_model_parallel_all_reduce(x)
+
+    dp_metadata = forward_context.dp_metadata
+    if dp_metadata is None or not is_ep_comm:
         pad_size = forward_context.pad_size
         if pad_size > 0:
             x = F.pad(x, (0, 0, 0, pad_size))
         return tensor_model_parallel_reduce_scatter(x, 0)
     else:
-        return tensor_model_parallel_all_reduce(x)
+        # padding
+        dp_size = get_dp_group().world_size
+        num_tokens_across_dp_cpu = \
+            get_forward_context().dp_metadata.num_tokens_across_dp_cpu
+        padded_x = torch.empty(
+            (dp_size, forward_context.padded_length, *x.shape[1:]),
+            device=x.device,
+            dtype=x.dtype)
+        offset = 0
+        for idx in range(dp_size):
+            num_tokens_dp = num_tokens_across_dp_cpu[idx]
+            padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp]
+            offset += num_tokens_dp
+
+        return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]),
+                                             0)
 
 
 def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
@@ -71,6 +114,33 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
     return
 
 
+def _maybe_all_gather_and_maybe_unpad_fake(
+        x: torch.Tensor,
+        label: bool,
+        is_ep_comm: bool = False) -> torch.Tensor:
+
+    if get_forward_context().sp_enabled and label:
+        return torch.empty(
+            (x.shape[0] * get_tensor_model_parallel_world_size(),
+             *x.shape[1:]),
+            device=x.device,
+            dtype=x.dtype)
+
+    return x
+
+
+def _maybe_pad_and_reduce_fake(x: torch.Tensor,
+                               is_ep_comm: bool = False) -> torch.Tensor:
+    if get_forward_context().sp_enabled:
+        return torch.empty(
+            (x.shape[0] // get_tensor_model_parallel_world_size(),
+             *x.shape[1:]),
+            device=x.device,
+            dtype=x.dtype)
+
+    return x
+
+
 def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
                                                prefix: str) -> None:
     return
@@ -158,7 +228,8 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
         final_hidden_states: torch.Tensor) -> torch.Tensor:
     forward_context = get_forward_context()
     moe_comm_type = forward_context.moe_comm_type
-    if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
+    if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2
+                         } or forward_context.sp_enabled:
         return final_hidden_states
     else:
         return tensor_model_parallel_all_reduce(final_hidden_states)
@@ -166,13 +237,13 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
 
 direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
                           op_func=_maybe_all_gather_and_maybe_unpad_impl,
-                          fake_impl=lambda x, label: x,
+                          fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
                           mutates_args=[],
                           dispatch_key="PrivateUse1")
 
 direct_register_custom_op(op_name="maybe_pad_and_reduce",
                           op_func=_maybe_pad_and_reduce_impl,
-                          fake_impl=lambda x: x,
+                          fake_impl=_maybe_pad_and_reduce_fake,
                           mutates_args=[],
                           dispatch_key="PrivateUse1")
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index b9eaf5fc9..2218091a9 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -31,7 +31,7 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
                                        init_ascend_config)
 from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
                                         delete_torchair_cache_file)
-from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p,
+from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p,
                                update_aclgraph_sizes)
 
 if TYPE_CHECKING:
@@ -211,6 +211,21 @@ class NPUPlatform(Platform):
         # set cudaprah sizes before extending `compilation_config.splitting_ops`
         vllm_config._set_cudagraph_sizes()
 
+        # TODO delete graph size update here when compilation_config.pass_config.enable_sequence_parallelism
+        # is supported by vllm-ascend.
+        if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \
+                enable_sp(vllm_config):
+            original_sizes = compilation_config.cudagraph_capture_sizes
+            sp_aclgraph_sizes = \
+                vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
+            assert sp_aclgraph_sizes, (
+                f"cudagraph_capture_sizes {original_sizes} does not contain "
+                f"values that are multiples of tp_size "
+                f"{vllm_config.parallel_config.tensor_parallel_size}")
+            if len(sp_aclgraph_sizes) != len(original_sizes):
+                compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes
+                vllm_config.compilation_config.init_with_cudagraph_sizes(
+                    sp_aclgraph_sizes)
 
         # TODO: Full graph is fully supported later, and the default value will be set to full graph.
         if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 9e64d47db..6cc903aad 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -55,6 +55,7 @@ _PREFETCH_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
+_IS_MOE_MODEL = None
 
 
 def is_310p():
@@ -609,12 +610,24 @@ def enable_sp(vllm_config=None) -> bool:
         vllm_config = get_current_vllm_config()
     return (
         vllm_config.compilation_config.pass_config.enable_sequence_parallelism
-        or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM)
+        or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
+        # FlashComm v1 should be enabled via the env var VLLM_ASCEND_ENABLE_FLASHCOMM1.
+        # VLLM_ASCEND_ENABLE_FLASHCOMM is retained here for backward compatibility.
+        or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
+
+
+# TODO: remove this once vLLM provides an equivalent helper.
+def shared_expert_dp_enabled() -> bool:
+    return get_ascend_config().enable_shared_expert_dp or enable_sp()
 
 
 def is_moe_model(vllm_config: VllmConfig):
-    config = vllm_config.model_config.hf_config
-    return any('experts' in key.lower() for key in config.to_dict())
+    global _IS_MOE_MODEL
+    if _IS_MOE_MODEL is None:
+        config = vllm_config.model_config.hf_config
+        _IS_MOE_MODEL = any('experts' in key.lower()
+                            for key in config.to_dict())
+    return _IS_MOE_MODEL
 
 
 def weak_ref_tensor(tensor: Any) -> Any:
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 95acbe4c8..69daf02f1 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -20,6 +20,7 @@
 import copy
 import gc
 import itertools
+import math
 import re
 import time
 from collections import defaultdict
@@ -128,8 +129,8 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendSocVersion, ProfileExecuteDuration,
-                               get_ascend_soc_version, is_310p, is_enable_nz,
-                               lmhead_tp_enable)
+                               enable_sp, get_ascend_soc_version, is_310p,
+                               is_enable_nz, lmhead_tp_enable)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
 
 if TYPE_CHECKING:
@@ -1210,6 +1211,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             # Add padding to the batch size.
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 total_num_scheduled_tokens)
+        elif self.use_aclgraph and enable_sp(self.vllm_config):
+            # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size,
+            # the model will fall back to running its FX graph in eager mode.
+            # In this case, when sequence parallelism is enabled, we need to pad tokens to align
+            # with tp_size because pad_size cannot be captured by the FX graph
+            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+            num_input_tokens = math.ceil(
+                total_num_scheduled_tokens / tp_size) * tp_size
         else:
             # Eager mode.
             num_input_tokens = total_num_scheduled_tokens
@@ -1850,7 +1859,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             raise ValueError(f"Unsupported soc_version: {soc_version}")
 
         if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
-            moe_comm_type = MoECommType.NAIVE_MULTICAST
+            if enable_sp():
+                moe_comm_type = MoECommType.ALLGATHER
+            else:
+                moe_comm_type = MoECommType.NAIVE_MULTICAST
 
         # PanguProMoE only supports allgather
         if model_type == "PanguProMoE":
@@ -2314,6 +2326,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }
 
+        # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
+        # If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size.
+        if self.use_aclgraph and enable_sp(self.vllm_config):
+            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+            num_tokens = math.ceil(num_tokens / tp_size) * tp_size
+
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
          _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)
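
For reference, here is a rank-local sketch of the padding and gather bookkeeping that the `maybe_all_gather_and_maybe_unpad` / `maybe_pad_and_reduce` custom ops perform on the plain tensor-parallel path (the `is_ep_comm=False` branch). It only simulates the shape handling with ordinary tensor operations; `pad_and_split` and `gather_and_unpad` are illustrative stand-ins, not the NPU collectives the patch actually calls.

```python
import torch
import torch.nn.functional as F


def pad_and_split(x: torch.Tensor, tp_size: int) -> tuple[list[torch.Tensor], int]:
    """Pad the token dim to a multiple of tp_size and split into per-rank shards,
    mirroring the layout a reduce-scatter leaves on each TP rank."""
    num_tokens = x.shape[0]
    pad_size = (tp_size - num_tokens % tp_size) % tp_size
    if pad_size > 0:
        x = F.pad(x, (0, 0, 0, pad_size))  # pad rows at the end of the token dim
    return list(x.chunk(tp_size, dim=0)), pad_size


def gather_and_unpad(shards: list[torch.Tensor], pad_size: int) -> torch.Tensor:
    """Concatenate per-rank shards (the all-gather) and drop the padded rows."""
    x = torch.cat(shards, dim=0)
    return x[:-pad_size, :] if pad_size > 0 else x


x = torch.randn(10, 8)                            # 10 tokens, hidden size 8
shards, pad_size = pad_and_split(x, tp_size=4)    # padded to 12 rows -> 3 rows per rank
assert pad_size == 2 and all(s.shape[0] == 3 for s in shards)
assert torch.equal(gather_and_unpad(shards, pad_size), x)
```

Replacing the attention/MLP all-reduce with a reduce-scatter before the projection and an all-gather afterwards means each TP rank only holds `num_tokens / tp_size` rows in between, which is why the token count has to be padded to a multiple of `tp_size` in the first place.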
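The EP branch (`is_ep_comm=True`) additionally has to account for different token counts per DP rank: every rank is padded to the same `padded_length`, and the padding is stripped per rank after the gather. The following single-process sketch reproduces only that layout bookkeeping, using `num_tokens_across_dp` and `padded_length` as in the forward context; the EP all-gather and reduce-scatter collectives themselves are not modeled.

```python
import torch


def unpad_after_ep_all_gather(x: torch.Tensor, num_tokens_across_dp: torch.Tensor,
                              padded_length: int) -> torch.Tensor:
    """x has shape (dp_size * padded_length, hidden), as an EP all-gather would
    return it. Drop each DP rank's padding rows and concatenate the valid rows."""
    dp_size = num_tokens_across_dp.numel()
    hidden_dims = x.shape[1:]
    x = x.view(dp_size, padded_length, *hidden_dims)
    result = torch.empty((int(num_tokens_across_dp.sum()), *hidden_dims),
                         dtype=x.dtype, device=x.device)
    offset = 0
    for idx in range(dp_size):
        n = int(num_tokens_across_dp[idx])
        result[offset:offset + n] = x[idx, :n]
        offset += n
    return result


def pad_before_ep_reduce_scatter(x: torch.Tensor, num_tokens_across_dp: torch.Tensor,
                                 padded_length: int) -> torch.Tensor:
    """Inverse layout change: place each DP rank's rows into a fixed-size padded slot."""
    dp_size = num_tokens_across_dp.numel()
    padded = torch.zeros((dp_size, padded_length, *x.shape[1:]),
                         dtype=x.dtype, device=x.device)
    offset = 0
    for idx in range(dp_size):
        n = int(num_tokens_across_dp[idx])
        padded[idx, :n] = x[offset:offset + n]
        offset += n
    return padded.view(-1, *x.shape[1:])


counts = torch.tensor([5, 3])     # valid tokens per DP rank
padded_length = 8                 # max_tokens_across_dp rounded up to a tp_size multiple
flat = torch.randn(int(counts.sum()), 16)
packed = pad_before_ep_reduce_scatter(flat, counts, padded_length)
assert torch.equal(unpad_after_ep_all_gather(packed, counts, padded_length), flat)
```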
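Finally, the alignment math that appears in `ascend_forward_context.py`, `platform.py`, and `model_runner_v1.py` reduces to rounding token counts up to a multiple of `tp_size` and keeping only capture sizes that are divisible by it. The helpers below are hypothetical illustrations of that intent; in particular, `keep_tp_aligned_capture_sizes` is only assumed to approximate what vLLM's `update_sizes_for_sequence_parallelism` does, not a copy of it.

```python
import math


def round_up_to_tp_multiple(num_tokens: int, tp_size: int) -> int:
    """Same rounding the runner and forward context apply before padding/capture."""
    return math.ceil(num_tokens / tp_size) * tp_size


def keep_tp_aligned_capture_sizes(capture_sizes: list[int], tp_size: int) -> list[int]:
    """Hypothetical stand-in for the filtering effect of
    update_sizes_for_sequence_parallelism: only sizes divisible by tp_size can be
    captured when sequence parallelism splits tokens across TP ranks."""
    return sorted(size for size in capture_sizes if size % tp_size == 0)


assert round_up_to_tp_multiple(1000, 4) == 1000
assert round_up_to_tp_multiple(1001, 4) == 1004
assert keep_tp_aligned_capture_sizes([1, 2, 4, 8, 12, 18, 32], tp_size=4) == [4, 8, 12, 32]
```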