[Feat] Flash comm allgather ep (#3334)

Support Flash Comm v1 (Sequence Parallelism) for AllGather EP.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
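
A quick way to try the feature (a hedged sketch: only the `VLLM_ASCEND_ENABLE_FLASHCOMM1` environment variable and the TP > 1 requirement come from this change; the model name and `LLM` arguments are illustrative placeholders):

```python
# Hedged sketch: enable Flash Comm v1 (sequence parallelism) on the AllGather-EP path.
# Only the env var name comes from this PR; model and LLM arguments are illustrative.
import os
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"  # replaces legacy VLLM_ASCEND_ENABLE_FLASHCOMM

from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",    # illustrative MoE model
    tensor_parallel_size=4,        # Flash Comm v1 only takes effect when TP > 1
    enable_expert_parallel=True,   # AllGather EP path targeted by this PR
)
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16)))
```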

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
This commit is contained in:
realliujiaxu
2025-10-15 19:36:32 +08:00
committed by GitHub
parent 8abe517870
commit f69a83b7ba
15 changed files with 283 additions and 78 deletions

View File

@ -166,7 +166,7 @@ def test_sp_for_qwen3_moe() -> None:
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
example_prompts = [
"Hello, my name is",

View File

@ -500,9 +500,12 @@ class TestAscendMLAImpl(TestBase):
mock_up_proj.assert_called_once()
mock_npu_fused_infer_attention_score.assert_called_once()
@patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
@patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
def test_mla_preprocess(self, magic_npu_fetch):
def test_mla_preprocess(self, magic_npu_fetch,
mock_maybe_all_gather_and_maybe_unpad):
magic_npu_fetch.return_value = MagicMock()
mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
batch_size = 4
seq_len = 8
hidden_size = 1024

View File

@ -42,9 +42,11 @@ def test_row_parallel_linear(cls, mock_distributed):
assert output[0].shape == (2, 4, 64)
@patch("vllm_ascend.models.layers.mla.get_forward_context")
@patch("torch.ops.vllm.mla_forward")
@patch("torch_npu.npu_rms_norm")
def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
mock_forward_context,
mock_distributed, base_config):
mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
# Make a fake ascend config because of the AscendLinearBase
@ -54,6 +56,9 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
vllm_config.parallel_config.tensor_parallel_size = 1
vllm_config.kv_transfer_config = None
ascend_config.init_ascend_config(vllm_config)
dummy_forward_context = MagicMock()
dummy_forward_context.sp_enabled = False
mock_forward_context.return_value = dummy_forward_context
attn = CustomDeepseekV2MLAAttention(config=base_config,
hidden_size=128,

View File

@ -11,7 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
set_forward_context)
import vllm_ascend.envs as envs_ascend
from vllm_ascend.utils import enable_sp
from vllm_ascend.utils import enable_sp, is_moe_model
if TYPE_CHECKING:
from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod
@ -112,15 +112,20 @@ def set_ascend_forward_context(
# Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
# the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
# the performance may degrade due to the switching of communication methods.
sp_enabled = enable_sp(vllm_config) and \
tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000
if is_moe_model(vllm_config):
sp_enabled = enable_sp(vllm_config) and \
tp_world_size > 1
else:
sp_enabled = enable_sp(vllm_config) and \
tp_world_size > 1 and \
num_tokens is not None and num_tokens > 1000
if sp_enabled:
pad_size = (tp_world_size -
(num_tokens % tp_world_size)) % tp_world_size
forward_context.pad_size = pad_size
forward_context.sp_enabled = sp_enabled
forward_context.num_tokens = num_tokens
# set this for rope forward_oot using
forward_context.is_first_layer = True
@ -169,8 +174,14 @@ def set_ascend_forward_context(
dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None:
max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
)
max_tokens_across_dp = \
forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
if sp_enabled:
padded_length = (max_tokens_across_dp + tp_world_size -
1) // tp_world_size * tp_world_size
pad_size = padded_length - num_tokens
forward_context.padded_length = padded_length
forward_context.pad_size = pad_size
else:
max_tokens_across_dp = num_tokens
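
For reference, a small standalone sketch of the padding arithmetic introduced above (a hedged illustration with made-up token counts, not code from this change):

```python
# Hedged sketch of the padding math in set_ascend_forward_context (illustrative values).
tp_world_size = 4

# Single-DP case: pad num_tokens up to a multiple of tp_world_size.
num_tokens = 1023
pad_size = (tp_world_size - (num_tokens % tp_world_size)) % tp_world_size  # -> 1
assert (num_tokens + pad_size) % tp_world_size == 0

# DP case: every DP rank pads to the same padded_length derived from the DP-wide max,
# so the EP all-gather sees equally sized blocks from each rank.
max_tokens_across_dp = 1500
padded_length = (max_tokens_across_dp + tp_world_size - 1) // tp_world_size * tp_world_size  # -> 1500
pad_size = padded_length - num_tokens  # -> 477
print(pad_size, padded_length)
```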

View File

@ -9,7 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
MLAAttentionImpl)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
@ -1128,10 +1128,11 @@ class AscendMLAImpl(MLAAttentionImpl):
q_c = hidden_states
kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
# Process for shared_expert_dp
if need_gather_q_kv:
q_c = get_tp_group().all_gather(q_c, 0)
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
# Process for Flash Comm V1
q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
q_c, need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split, need_gather_q_kv)
decode_preprocess_res = None
prefill_preprocess_res = None
if has_prefill:
@ -1200,8 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
num_decode_tokens = attn_metadata.num_decode_tokens
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_tokens, ...]
o_proj_input_shape = (num_actual_tokens,
o_proj_input_shape = (get_forward_context().num_tokens,
self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states.dtype,
@ -1248,7 +1248,8 @@ class AscendMLAImpl(MLAAttentionImpl):
o_proj_input[num_decode_tokens:] = output_prefill
current_ms_metadata.after_comm_event.record()
else:
o_proj_input[num_decode_tokens:] = output_prefill
o_proj_input[
num_decode_tokens:num_actual_tokens] = output_prefill
# O proj
current_ms_metadata = get_multistream_comm_context()
MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
@ -1258,20 +1259,14 @@ class AscendMLAImpl(MLAAttentionImpl):
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(
o_proj_input,
is_prefill=prefill_preprocess_res is not None,
is_force_scatter=self.enable_shared_expert_dp)[0]
output[...] = self.o_proj(o_proj_input)[0]
else:
with torch.npu.stream(current_ms_metadata.comm_stream):
maybe_npu_prefetch(inputs=self.o_proj.weight,
dependency=o_proj_input,
max_size=MAX_O_PROJ_PREFETCH_SIZE,
enabled=self.enable_prefetch)
output[...] = self.o_proj(
o_proj_input,
is_prefill=prefill_preprocess_res is not None,
is_force_scatter=self.enable_shared_expert_dp)[0]
output[...] = self.o_proj(o_proj_input)[0]
current_ms_metadata.after_comm_event.record()
del o_proj_input

View File

@ -133,8 +133,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
# Whether to enable FlashComm optimization when tensor parallel is enabled.
# This feature will get better performance when concurrency is large.
"VLLM_ASCEND_ENABLE_FLASHCOMM":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
"VLLM_ASCEND_ENABLE_FLASHCOMM1":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
# Whether to enable MLP weight prefetch, only used in small concurrency.
"VLLM_ASCEND_ENABLE_PREFETCH_MLP":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),

View File

@ -300,12 +300,11 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
self.o_proj = CustomDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'

View File

@ -122,19 +122,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
need_gather_q_kv = get_forward_context().sp_enabled
output_shape = hidden_states.shape
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,

View File

@ -38,8 +38,9 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe.experts_selector import select_experts
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz,
npu_stream_switch)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
is_enable_nz, npu_stream_switch,
shared_expert_dp_enabled)
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
@ -417,6 +418,10 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
if self.multistream_overlap_shared_expert:
self.shared_expert_stream = torch.npu.Stream()
if enable_sp():
logger.info_once(
"Sequence parallelism is enabled, shared experts are replicated for best performance."
)
def forward(
self,
@ -444,7 +449,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
fused_output = AscendFusedMoE.forward_impl(
self,

View File

@ -49,7 +49,7 @@ from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
get_otp_group)
from vllm_ascend.utils import (dense_optim_enable, enable_sp,
matmul_allreduce_enable, mlp_tp_enable,
oproj_tp_enable)
oproj_tp_enable, shared_expert_dp_enabled)
class CustomLinearOp:
@ -418,7 +418,8 @@ def _get_row_parallel_op(
def get_parallel_op(disable_tp, prefix, layer, direct):
if disable_tp:
if disable_tp or ("shared_experts" in prefix
and shared_expert_dp_enabled()):
return None, 0, 1
custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
MLPRowParallelOp, OProjRowParallelOp,

View File

@ -27,7 +27,7 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import get_rm_router_logits_state
from vllm_ascend.utils import enable_sp, get_rm_router_logits_state
class FusedMoEPrepareAndFinalize(ABC):
@ -198,7 +198,7 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
"""
MoE communication strategy using MC2, which is based on All2All. Hence, it inherits
All2All and share the same finalize method.
All2All and share the same finalize method.
Designed for Ascend or environments requiring explicit padding and slicing control.
Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
"""
@ -277,9 +277,24 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
"""
MoE communication strategy using All-Gather + Reduce-Scatter.
Designed for DP > 1: gather inputs across DP ranks before MoE, scatter outputs after.
Uses `max_tokens_across_dp` from forward_context for padding alignment.
MoE communication strategy using All-Gather + Reduce-Scatter on EP group.
There are two sets of prepare and finalize:
1. _prepare_with_dp_group/_finalize_with_dp_group: When sequence parallelism is not enabled,
we gather inputs across DP ranks before MoE, scatter outputs after.
The communication and calculation process is as follows (AG, AR and RS
are abbreviations for All-Gather, All-Reduce and Reduce-Scatter, respectively):
Attn → TP AR → DP AG → MoE → DP RS → TP AR
2. _prepare_with_ep_group/_finalize_with_ep_group: When sequence parallelism is enabled,
the above process becomes:
TP AG → Attn → TP RS → TP AG → DP AG → MoE → DP RS → TP RS
This strategy further combines TP AG + DP AG into EP All-Gather and TP RS + DP RS
into EP Reduce-Scatter to improve communication performance. The optimized process is as follows:
TP AG → Attn → TP RS → EP AG → MoE → EP RS
"""
def prepare(
@ -289,6 +304,42 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
Preparation steps:
AllGather hidden_states and router_logits to form global tensors.
Returns:
Tuple of (global_hidden_states, global_router_logits, None, None)
"""
if enable_sp():
return self._prepare_with_ep_group(hidden_states, router_logits)
return self._prepare_with_dp_group(hidden_states, router_logits,
enable_shared_expert_dp,
replace_allreduce, gate)
def _prepare_with_ep_group(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
router_logits, True, True)
return hidden_states, router_logits, None, None
def _prepare_with_dp_group(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
gate=None
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
"""
@ -301,7 +352,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
Tuple of (global_hidden_states, global_router_logits, None, None)
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1:
forward_context = get_forward_context()
max_tokens_across_dp = forward_context.max_tokens_across_dp
@ -323,7 +373,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
else:
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
return hidden_states, router_logits, None, None
def finalize(self,
@ -331,6 +380,36 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
"""
Finalization steps:
Reduce Scatter hidden states.
Returns:
Tensor with shape [local_num_tokens, hidden_size]
"""
if enable_sp():
return self._finalize_with_ep_group(hidden_states)
return self._finalize_with_dp_group(hidden_states, reduce_results)
def _finalize_with_ep_group(self,
hidden_states: torch.Tensor) -> torch.Tensor:
"""
Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled:
1. reduce_results is False usually happens when the model has shared experts and needs to
all-reduce hidden states after the results of shared experts and routed experts are added in FusedMoE.
We do a reduce-scatter on the hidden states here, then skip the all-reduce in FusedMoE and add it to the
result of the shared experts.
2. reduce_results is True usually happens when the model has no shared experts. We still do a reduce-scatter
here, then skip the all-reduce in FusedMoE.
"""
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
hidden_states, True)
return hidden_states
def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
2. Slice to original local token count.
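
The class docstring added above describes fusing the TP and DP all-gathers into a single EP all-gather. Below is a hedged, single-process sketch of why that fusion is shape-equivalent; it assumes EP ranks are ordered DP-major with TP ranks contiguous inside each DP group (an assumption for illustration, not something asserted by this diff):

```python
# Hedged toy: gathering over the EP group equals a TP all-gather followed by a DP
# all-gather, given a DP-major, TP-minor rank ordering (assumption).
import torch

tp, dp, tokens_per_rank, hidden = 2, 2, 3, 4

# One shard per (dp_rank, tp_rank); values encode the owning rank for readability.
shards = {(d, t): torch.full((tokens_per_rank, hidden), float(d * tp + t))
          for d in range(dp) for t in range(tp)}

# Path 1: TP all-gather inside each DP rank, then DP all-gather of the results.
tp_gathered = [torch.cat([shards[(d, t)] for t in range(tp)], dim=0) for d in range(dp)]
dp_then_tp = torch.cat(tp_gathered, dim=0)

# Path 2: a single all-gather over the EP group (= all dp*tp ranks) in the same order.
ep_gathered = torch.cat([shards[(d, t)] for d in range(dp) for t in range(tp)], dim=0)

assert torch.equal(dp_then_tp, ep_gathered)
print(ep_gathered.shape)  # (dp * tp * tokens_per_rank, hidden) = (12, 4)
```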

View File

@ -1,7 +1,9 @@
import torch
import torch.nn.functional as F
import torch_npu
from vllm.distributed import (tensor_model_parallel_all_gather,
from vllm.distributed import (get_dp_group, get_ep_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce,
tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
@ -13,8 +15,10 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.utils import npu_stream_switch, prefetch_stream
def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
label: bool) -> torch.Tensor:
def _maybe_all_gather_and_maybe_unpad_impl(
x: torch.Tensor,
label: bool,
is_ep_comm: bool = False) -> torch.Tensor:
try:
forward_context = get_forward_context()
except AssertionError:
@ -22,27 +26,66 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
sp_enabled = forward_context.sp_enabled
if sp_enabled and label:
x = tensor_model_parallel_all_gather(x, 0)
pad_size = forward_context.pad_size
if pad_size > 0:
x = x[:-pad_size, :]
dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm:
x = tensor_model_parallel_all_gather(x, 0)
pad_size = forward_context.pad_size
if pad_size > 0:
x = x[:-pad_size, :]
else:
x = get_ep_group().all_gather(x, 0)
# unpad
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
result = torch.empty(
(num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
device=x.device,
dtype=x.dtype)
dp_size = get_dp_group().world_size
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
offset = 0
for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx]
result[offset:offset +
num_tokens_dp, :] = x[idx, :num_tokens_dp, :]
offset += num_tokens_dp
x = result
return x
def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
def _maybe_pad_and_reduce_impl(x: torch.Tensor,
is_ep_comm: bool = False) -> torch.Tensor:
try:
forward_context = get_forward_context()
except AssertionError:
return tensor_model_parallel_all_reduce(x)
sp_enabled = forward_context.sp_enabled
if sp_enabled:
if not forward_context.sp_enabled:
return tensor_model_parallel_all_reduce(x)
dp_metadata = forward_context.dp_metadata
if dp_metadata is None or not is_ep_comm:
pad_size = forward_context.pad_size
if pad_size > 0:
x = F.pad(x, (0, 0, 0, pad_size))
return tensor_model_parallel_reduce_scatter(x, 0)
else:
return tensor_model_parallel_all_reduce(x)
# padding
dp_size = get_dp_group().world_size
num_tokens_across_dp_cpu = \
get_forward_context().dp_metadata.num_tokens_across_dp_cpu
padded_x = torch.empty(
(dp_size, forward_context.padded_length, *x.shape[1:]),
device=x.device,
dtype=x.dtype)
offset = 0
for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx]
padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp]
offset += num_tokens_dp
return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]),
0)
def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
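
The EP branches above pad every DP rank's tokens to `padded_length` before the collective and strip the padding afterwards. A hedged, single-process toy of that bookkeeping (made-up sizes, no real communication):

```python
# Hedged toy of the pad/unpad bookkeeping around the EP collectives (no communication).
import torch

hidden = 4
num_tokens_across_dp = torch.tensor([3, 5])   # valid tokens per DP rank
padded_length = 6                              # ceil(max / tp) * tp with tp = 2
dp_size = num_tokens_across_dp.numel()

# "Gathered" tensor as the EP all-gather would lay it out: one padded block per DP rank.
blocks = [torch.arange(n * hidden, dtype=torch.float32).view(n, hidden)
          for n in num_tokens_across_dp.tolist()]
padded = torch.zeros(dp_size, padded_length, hidden)
for idx, blk in enumerate(blocks):
    padded[idx, :blk.shape[0]] = blk

# Unpad: concatenate only the valid rows of each block (mirrors the loop in
# _maybe_all_gather_and_maybe_unpad_impl above).
unpadded = torch.cat([padded[i, :n] for i, n in enumerate(num_tokens_across_dp.tolist())])
assert unpadded.shape[0] == int(num_tokens_across_dp.sum())

# Re-pad: scatter the valid rows back into padded blocks (mirrors _maybe_pad_and_reduce_impl).
repadded = torch.zeros_like(padded)
offset = 0
for idx, n in enumerate(num_tokens_across_dp.tolist()):
    repadded[idx, :n] = unpadded[offset:offset + n]
    offset += n
assert torch.equal(repadded, padded)
```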
@ -71,6 +114,33 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
return
def _maybe_all_gather_and_maybe_unpad_fake(
x: torch.Tensor,
label: bool,
is_ep_comm: bool = False) -> torch.Tensor:
if get_forward_context().sp_enabled and label:
return torch.empty(
(x.shape[0] * get_tensor_model_parallel_world_size(),
*x.shape[1:]),
device=x.device,
dtype=x.dtype)
return x
def _maybe_pad_and_reduce_fake(x: torch.Tensor,
is_ep_comm: bool = False) -> torch.Tensor:
if get_forward_context().sp_enabled:
return torch.empty(
(x.shape[0] // get_tensor_model_parallel_world_size(),
*x.shape[1:]),
device=x.device,
dtype=x.dtype)
return x
def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
prefix: str) -> None:
return
@ -158,7 +228,8 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
final_hidden_states: torch.Tensor) -> torch.Tensor:
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2
} or forward_context.sp_enabled:
return final_hidden_states
else:
return tensor_model_parallel_all_reduce(final_hidden_states)
@ -166,13 +237,13 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
op_func=_maybe_all_gather_and_maybe_unpad_impl,
fake_impl=lambda x, label: x,
fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
mutates_args=[],
dispatch_key="PrivateUse1")
direct_register_custom_op(op_name="maybe_pad_and_reduce",
op_func=_maybe_pad_and_reduce_impl,
fake_impl=lambda x: x,
fake_impl=_maybe_pad_and_reduce_fake,
mutates_args=[],
dispatch_key="PrivateUse1")

View File

@ -31,7 +31,7 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
init_ascend_config)
from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
delete_torchair_cache_file)
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p,
from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p,
update_aclgraph_sizes)
if TYPE_CHECKING:
@ -211,6 +211,21 @@ class NPUPlatform(Platform):
# set cudaprah sizes before extending `compilation_config.splitting_ops`
vllm_config._set_cudagraph_sizes()
# TODO delete graph size update here when compilation_config.pass_config.enable_sequence_parallelism
# is supported by vllm-ascend.
if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \
enable_sp(vllm_config):
original_sizes = compilation_config.cudagraph_capture_sizes
sp_aclgraph_sizes = \
vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
assert sp_aclgraph_sizes, (
f"cudagraph_capture_sizes {original_sizes} does not contain"
f"values that are multiples of tp_size "
f"{vllm_config.parallel_config.tensor_parallel_size}")
if len(sp_aclgraph_sizes) != len(original_sizes):
compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes
vllm_config.compilation_config.init_with_cudagraph_sizes(
sp_aclgraph_sizes)
# TODO: Full graph is fully supported later, and the default value will be set to full graph.
if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:
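
For intuition, a hedged sketch of the size adjustment this block relies on; the real work is done by vLLM's `update_sizes_for_sequence_parallelism`, and the assumption here is simply that it keeps only capture sizes divisible by `tp_size`:

```python
# Hedged sketch (assumption: only capture sizes that are multiples of tp_size survive under SP).
def sp_capture_sizes(capture_sizes: list[int], tp_size: int) -> list[int]:
    sizes = [s for s in capture_sizes if s % tp_size == 0]
    assert sizes, (f"cudagraph_capture_sizes {capture_sizes} does not contain "
                   f"values that are multiples of tp_size {tp_size}")
    return sizes

print(sp_capture_sizes([1, 2, 4, 8, 16, 32], tp_size=4))  # -> [4, 8, 16, 32]
```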

View File

@ -55,6 +55,7 @@ _PREFETCH_STREAM = None
_ASCEND_CUSTOMOP_IS_REIGISTERED = False
_DEFAULT_BUFFER_SIZE = 200
_MIN_DP_BUFFER_SIZE = 50
_IS_MOE_MODEL = None
def is_310p():
@ -609,12 +610,24 @@ def enable_sp(vllm_config=None) -> bool:
vllm_config = get_current_vllm_config()
return (
vllm_config.compilation_config.pass_config.enable_sequence_parallelism
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM)
or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
# Flash comm 1 should be enabled by env VLLM_ASCEND_ENABLE_FLASHCOMM1
# We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))
# TODO remove it after vllm has this func
def shared_expert_dp_enabled() -> bool:
return get_ascend_config().enable_shared_expert_dp or enable_sp()
def is_moe_model(vllm_config: VllmConfig):
config = vllm_config.model_config.hf_config
return any('experts' in key.lower() for key in config.to_dict())
global _IS_MOE_MODEL
if _IS_MOE_MODEL is None:
config = vllm_config.model_config.hf_config
_IS_MOE_MODEL = any('experts' in key.lower()
for key in config.to_dict())
return _IS_MOE_MODEL
def weak_ref_tensor(tensor: Any) -> Any:

View File

@ -20,6 +20,7 @@
import copy
import gc
import itertools
import math
import re
import time
from collections import defaultdict
@ -128,8 +129,8 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
get_ascend_soc_version, is_310p, is_enable_nz,
lmhead_tp_enable)
enable_sp, get_ascend_soc_version, is_310p,
is_enable_nz, lmhead_tp_enable)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
if TYPE_CHECKING:
@ -1210,6 +1211,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
# Add padding to the batch size.
num_input_tokens = self.vllm_config.pad_for_cudagraph(
total_num_scheduled_tokens)
elif self.use_aclgraph and enable_sp(self.vllm_config):
# When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size,
# the model will fall back to running its FX graph in eager mode.
# In this case, when sequence parallelism is enabled, we need to pad tokens to align
# with tp_size because pad_size cannot be captured by the FX graph
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
num_input_tokens = math.ceil(
total_num_scheduled_tokens / tp_size) * tp_size
else:
# Eager mode.
num_input_tokens = total_num_scheduled_tokens
@ -1850,7 +1859,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
raise ValueError(f"Unsupported soc_version: {soc_version}")
if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
moe_comm_type = MoECommType.NAIVE_MULTICAST
if enable_sp():
moe_comm_type = MoECommType.ALLGATHER
else:
moe_comm_type = MoECommType.NAIVE_MULTICAST
# PanguProMoE only supports allgather
if model_type == "PanguProMoE":
@ -2314,6 +2326,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
}
# In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
# If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size.
if self.use_aclgraph and enable_sp(self.vllm_config):
tp_size = self.vllm_config.parallel_config.tensor_parallel_size
num_tokens = math.ceil(num_tokens / tp_size) * tp_size
# Padding for DP
(num_tokens, num_tokens_across_dp, with_prefill,
_) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)