Mirror of https://github.com/vllm-project/vllm-ascend.git, synced 2025-10-20 13:43:53 +08:00
[Feat] Flash comm allgather EP (#3334)

Support flash comm v1 (Sequence Parallelism) for Allgather EP.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Co-authored-by: zhaozx-cn <zhaozx2116@163.com>
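Before the diff, a minimal sketch (not part of this commit; the model name and parallel sizes are placeholders) of how the new path might be exercised: the VLLM_ASCEND_ENABLE_FLASHCOMM1 switch is read at import time, and the optimization targets MoE models served with tensor/expert parallelism.

import os

# Enable flash comm v1 (sequence parallelism) on vllm-ascend before vLLM is imported.
os.environ["VLLM_ASCEND_ENABLE_FLASHCOMM1"] = "1"

from vllm import LLM, SamplingParams

# Placeholder MoE checkpoint and parallel sizes; adjust to the actual deployment.
llm = LLM(model="Qwen/Qwen3-30B-A3B",
          tensor_parallel_size=4,
          enable_expert_parallel=True)
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16)))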
@@ -166,7 +166,7 @@ def test_sp_for_qwen3_moe() -> None:
 @pytest.mark.parametrize("enforce_eager", [True, False])
 @pytest.mark.parametrize("model", QWEN_DENSE_MODELS)
 @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_DENSE_OPTIMIZE": "1"})
-@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM": "1"})
+@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"})
 def test_models_distributed_Qwen_Dense_with_flashcomm_v1(model, enforce_eager):
     example_prompts = [
         "Hello, my name is",

@@ -500,9 +500,12 @@ class TestAscendMLAImpl(TestBase):
         mock_up_proj.assert_called_once()
         mock_npu_fused_infer_attention_score.assert_called_once()

+    @patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad")
     @patch("vllm_ascend.attention.mla_v1.maybe_npu_prefetch")
-    def test_mla_preprocess(self, magic_npu_fetch):
+    def test_mla_preprocess(self, magic_npu_fetch,
+                            mock_maybe_all_gather_and_maybe_unpad):
         magic_npu_fetch.return_value = MagicMock()
+        mock_maybe_all_gather_and_maybe_unpad.side_effect = lambda x, label: x
         batch_size = 4
         seq_len = 8
         hidden_size = 1024

@@ -42,9 +42,11 @@ def test_row_parallel_linear(cls, mock_distributed):
     assert output[0].shape == (2, 4, 64)


+@patch("vllm_ascend.models.layers.mla.get_forward_context")
 @patch("torch.ops.vllm.mla_forward")
 @patch("torch_npu.npu_rms_norm")
 def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
+                                          mock_forward_context,
                                           mock_distributed, base_config):
     mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128))
     # Make a fake ascend config because of the AscendLinearBase
@@ -54,6 +56,9 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward,
     vllm_config.parallel_config.tensor_parallel_size = 1
     vllm_config.kv_transfer_config = None
     ascend_config.init_ascend_config(vllm_config)
+    dummy_forward_context = MagicMock()
+    dummy_forward_context.sp_enabled = False
+    mock_forward_context.return_value = dummy_forward_context

     attn = CustomDeepseekV2MLAAttention(config=base_config,
                                         hidden_size=128,

@@ -11,7 +11,7 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
                                   set_forward_context)

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.utils import enable_sp
+from vllm_ascend.utils import enable_sp, is_moe_model

 if TYPE_CHECKING:
     from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod

@@ -112,15 +112,20 @@ def set_ascend_forward_context(
         # Currently, it is an empirical value. In normal scenarios, if the concurrency exceeds this threshold,
         # the performance benefits can be maximized. Conversely, if the concurrency is below the threshold,
         # the performance may degrade due to the switching of communication methods.
-        sp_enabled = enable_sp(vllm_config) and \
-            tp_world_size > 1 and \
-            num_tokens is not None and num_tokens > 1000
+        if is_moe_model(vllm_config):
+            sp_enabled = enable_sp(vllm_config) and \
+                tp_world_size > 1
+        else:
+            sp_enabled = enable_sp(vllm_config) and \
+                tp_world_size > 1 and \
+                num_tokens is not None and num_tokens > 1000

         if sp_enabled:
             pad_size = (tp_world_size -
                         (num_tokens % tp_world_size)) % tp_world_size
             forward_context.pad_size = pad_size
         forward_context.sp_enabled = sp_enabled
+        forward_context.num_tokens = num_tokens

         # set this for rope forward_oot using
         forward_context.is_first_layer = True

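A small illustration (assumed helper names, not from the repo) of the gating and padding arithmetic in the hunk above: MoE models turn sequence parallelism on whenever TP > 1, dense models additionally keep the empirical >1000-token threshold, and the token count is padded to a multiple of tp_world_size.

def sp_enabled_for(num_tokens, tp_world_size, is_moe, sp_requested=True):
    # MoE: SP whenever TP > 1; dense: also require the concurrency threshold.
    if not sp_requested or tp_world_size <= 1:
        return False
    return True if is_moe else (num_tokens is not None and num_tokens > 1000)

def pad_size_for(num_tokens, tp_world_size):
    # Pad so the sequence splits evenly across TP ranks.
    return (tp_world_size - num_tokens % tp_world_size) % tp_world_size

assert sp_enabled_for(512, 4, is_moe=True)
assert not sp_enabled_for(512, 4, is_moe=False)
assert pad_size_for(10, 4) == 2   # 10 tokens are padded to 12
assert pad_size_for(12, 4) == 0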
@@ -169,8 +174,14 @@ def set_ascend_forward_context(

         dp_world_size = get_dp_group().world_size
         if dp_world_size > 1 and forward_context.dp_metadata is not None:
-            max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
-            )
+            max_tokens_across_dp = \
+                forward_context.dp_metadata.max_tokens_across_dp_cpu.item()
+            if sp_enabled:
+                padded_length = (max_tokens_across_dp + tp_world_size -
+                                 1) // tp_world_size * tp_world_size
+                pad_size = padded_length - num_tokens
+                forward_context.padded_length = padded_length
+                forward_context.pad_size = pad_size
         else:
             max_tokens_across_dp = num_tokens

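The DP branch above rounds every rank to the same padded length; a toy check of that arithmetic (hypothetical values):

tp_world_size = 4
max_tokens_across_dp = 10          # largest batch among DP ranks
padded_length = (max_tokens_across_dp + tp_world_size -
                 1) // tp_world_size * tp_world_size
assert padded_length == 12

num_tokens = 9                     # this rank's own batch
pad_size = padded_length - num_tokens
assert pad_size == 3               # every rank all-gathers tensors of length 12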
@@ -9,7 +9,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
 from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)

@@ -1128,10 +1128,11 @@ class AscendMLAImpl(MLAAttentionImpl):
            q_c = hidden_states

        kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
-       # Process for shared_expert_dp
-       if need_gather_q_kv:
-           q_c = get_tp_group().all_gather(q_c, 0)
-           kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
+       # Process for Flash Comm V1
+       q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+           q_c, need_gather_q_kv)
+       kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+           kv_no_split, need_gather_q_kv)
        decode_preprocess_res = None
        prefill_preprocess_res = None
        if has_prefill:
@@ -1200,8 +1201,7 @@ class AscendMLAImpl(MLAAttentionImpl):
        num_decode_tokens = attn_metadata.num_decode_tokens
        # Inputs and outputs may be padded for CUDA graphs
        output_padded = output
        output = output[:num_actual_tokens, ...]
-       o_proj_input_shape = (num_actual_tokens,
+       o_proj_input_shape = (get_forward_context().num_tokens,
                              self.num_heads * self.v_head_dim)
        o_proj_input = torch.empty(o_proj_input_shape,
                                   dtype=hidden_states.dtype,
@@ -1248,7 +1248,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                o_proj_input[num_decode_tokens:] = output_prefill
                current_ms_metadata.after_comm_event.record()
            else:
-               o_proj_input[num_decode_tokens:] = output_prefill
+               o_proj_input[
+                   num_decode_tokens:num_actual_tokens] = output_prefill
        # O proj
        current_ms_metadata = get_multistream_comm_context()
        MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
@@ -1258,20 +1259,14 @@ class AscendMLAImpl(MLAAttentionImpl):
                               max_size=MAX_O_PROJ_PREFETCH_SIZE,
                               enabled=self.enable_prefetch)

-           output[...] = self.o_proj(
-               o_proj_input,
-               is_prefill=prefill_preprocess_res is not None,
-               is_force_scatter=self.enable_shared_expert_dp)[0]
+           output[...] = self.o_proj(o_proj_input)[0]
        else:
            with torch.npu.stream(current_ms_metadata.comm_stream):
                maybe_npu_prefetch(inputs=self.o_proj.weight,
                                   dependency=o_proj_input,
                                   max_size=MAX_O_PROJ_PREFETCH_SIZE,
                                   enabled=self.enable_prefetch)
-               output[...] = self.o_proj(
-                   o_proj_input,
-                   is_prefill=prefill_preprocess_res is not None,
-                   is_force_scatter=self.enable_shared_expert_dp)[0]
+               output[...] = self.o_proj(o_proj_input)[0]
                current_ms_metadata.after_comm_event.record()
        del o_proj_input

@@ -133,8 +133,8 @@ env_variables: Dict[str, Callable[[], Any]] = {
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MATMUL_ALLREDUCE", '0'))),
     # Whether to enable FlashComm optimization when tensor parallel is enabled.
     # This feature will get better performance when concurrency is large.
-    "VLLM_ASCEND_ENABLE_FLASHCOMM":
-    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))),
+    "VLLM_ASCEND_ENABLE_FLASHCOMM1":
+    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM1", '0'))),
     # Whether to enable MLP weight prefetch, only used in small concurrency.
     "VLLM_ASCEND_ENABLE_PREFETCH_MLP":
     lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_PREFETCH_MLP", '0'))),

@@ -300,12 +300,11 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
                                              bias=False,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.kv_b_proj")
-        self.o_proj = CustomDeepseekV2RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-            prefix=f"{prefix}.o_proj")
+        self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
+                                        self.hidden_size,
+                                        bias=False,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")

         if rope_scaling:
             rope_scaling["rope_type"] = 'deepseek_yarn'

@@ -122,19 +122,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
                 hidden_states: torch.Tensor,
                 kv_cache: Optional[torch.Tensor] = None,
                 attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
-        num_tokens = hidden_states.shape[0]
-        need_gather_q_kv = False
-        if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
-            # Simulate all gather to calculate output shape
-            num_tokens = num_tokens * self.tp_size
-            need_gather_q_kv = True
-        if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
-            output_shape = hidden_states.shape
-        else:
-            rows = num_tokens // self.tp_size
-            if num_tokens % self.tp_size:
-                rows += 1
-            output_shape = (rows, hidden_states.shape[1])
+        need_gather_q_kv = get_forward_context().sp_enabled
+        output_shape = hidden_states.shape
         # FIXME: This does not seem right, should make sure the buffer is fixed
         output = torch.empty(output_shape,
                              dtype=hidden_states.dtype,

@@ -38,8 +38,9 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.moe.experts_selector import select_experts
 from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
-from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, is_310p, is_enable_nz,
-                               npu_stream_switch)
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
+                               is_enable_nz, npu_stream_switch,
+                               shared_expert_dp_enabled)


 class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

@@ -417,6 +418,10 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
         if self.multistream_overlap_shared_expert:
             self.shared_expert_stream = torch.npu.Stream()
+        if enable_sp():
+            logger.info_once(
+                "Sequence parallelism is enabled, shared experts are replicated for best performance."
+            )

     def forward(
         self,
@@ -444,7 +449,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
         # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
         forward_context = get_forward_context()
         moe_comm_type = forward_context.moe_comm_type
-        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
+        if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
+                and not shared_expert_dp_enabled():
             shared_out = tensor_model_parallel_all_reduce(shared_out)
         fused_output = AscendFusedMoE.forward_impl(
             self,

@@ -49,7 +49,7 @@ from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
 from vllm_ascend.utils import (dense_optim_enable, enable_sp,
                                matmul_allreduce_enable, mlp_tp_enable,
-                               oproj_tp_enable)
+                               oproj_tp_enable, shared_expert_dp_enabled)


 class CustomLinearOp:
@@ -418,7 +418,8 @@ def _get_row_parallel_op(


 def get_parallel_op(disable_tp, prefix, layer, direct):
-    if disable_tp:
+    if disable_tp or ("shared_experts" in prefix
+                      and shared_expert_dp_enabled()):
         return None, 0, 1
     custom_op: Optional[Union[MLPColumnParallelOp, SequenceColumnParallelOp,
                               MLPRowParallelOp, OProjRowParallelOp,

@@ -27,7 +27,7 @@ from vllm.distributed.parallel_state import (
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig

-from vllm_ascend.utils import get_rm_router_logits_state
+from vllm_ascend.utils import enable_sp, get_rm_router_logits_state


 class FusedMoEPrepareAndFinalize(ABC):
@@ -198,7 +198,7 @@ class FusedMoEPrepareAndFinalizeWithAll2All(FusedMoEPrepareAndFinalize):
 class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):
     """
     MoE communication strategy using MC2, which is based on All2All. Hence, it inherits
     All2All and share the same finalize method.
     Designed for Ascend or environments requiring explicit padding and slicing control.
     Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
     """

@@ -277,9 +277,24 @@ class FusedMoEPrepareAndFinalizeWithMC2(FusedMoEPrepareAndFinalizeWithAll2All):

 class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
     """
-    MoE communication strategy using All-Gather + Reduce-Scatter.
-    Designed for DP > 1: gather inputs across DP ranks before MoE, scatter outputs after.
-    Uses `max_tokens_across_dp` from forward_context for padding alignment.
+    MoE communication strategy using All-Gather + Reduce-Scatter on the EP group.
+    There are two sets of prepare and finalize:
+    1. _prepare_with_dp_group/_finalize_with_dp_group: When sequence parallelism is not enabled,
+       we gather inputs across DP ranks before MoE and scatter outputs after.
+       The communication and calculation process is as follows (AG, AR and RS
+       are abbreviations for All-Gather, All-Reduce and Reduce-Scatter, respectively):
+
+       Attn → TP AR → DP AG → MoE → DP RS → TP AR
+
+    2. _prepare_with_ep_group/_finalize_with_ep_group: When sequence parallelism is enabled,
+       the above process becomes:
+
+       TP AG → Attn → TP RS → TP AG → DP AG → MoE → DP RS → TP RS
+
+       This strategy further combines TP AG + DP AG into an EP All-Gather and TP RS + DP RS
+       into an EP Reduce-Scatter to improve communication performance. The optimized process is:
+
+       TP AG → Attn → TP RS → EP AG → MoE → EP RS
     """

     def prepare(

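A toy emulation (torch.cat standing in for the collectives, toy sizes, no distributed backend) of why the combined EP all-gather described in the docstring above is equivalent to the TP AG + DP AG pair it replaces:

import torch

tp, dp, hidden = 2, 2, 8
ep = tp * dp                                            # EP group spans the TP x DP ranks
shards = [torch.randn(3, hidden) for _ in range(ep)]    # one SP shard per rank

# Two-step path: TP all-gather inside each DP rank, then DP all-gather across ranks.
tp_gathered = [torch.cat(shards[d * tp:(d + 1) * tp], dim=0) for d in range(dp)]
two_step = torch.cat(tp_gathered, dim=0)

# Combined path: one EP all-gather over all tp * dp shards.
one_step = torch.cat(shards, dim=0)

assert torch.equal(two_step, one_step)   # same global tensor, fewer collectives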
@@ -289,6 +304,42 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
             enable_shared_expert_dp: bool = False,
             replace_allreduce: bool = False,
             gate=None
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
+        """
+        Preparation steps:
+        AllGather hidden_states and router_logits to form global tensors.
+
+        Returns:
+            Tuple of (global_hidden_states, global_router_logits, None)
+        """
+        if enable_sp():
+            return self._prepare_with_ep_group(hidden_states, router_logits)
+
+        return self._prepare_with_dp_group(hidden_states, router_logits,
+                                           enable_shared_expert_dp,
+                                           replace_allreduce, gate)
+
+    def _prepare_with_ep_group(
+            self,
+            hidden_states: torch.Tensor,
+            router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
+        hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            hidden_states, True, True)
+        router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            router_logits, True, True)
+
+        return hidden_states, router_logits, None, None
+
+    def _prepare_with_dp_group(
+            self,
+            hidden_states: torch.Tensor,
+            router_logits: torch.Tensor,
+            enable_shared_expert_dp: bool = False,
+            replace_allreduce: bool = False,
+            gate=None
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
+               Optional[torch.Tensor]]:
+        """
@@ -301,7 +352,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
             Tuple of (global_hidden_states, global_router_logits, None, None)
         """
-        self.enable_shared_expert_dp = enable_shared_expert_dp

         if self.moe_config.dp_size > 1:
             forward_context = get_forward_context()
             max_tokens_across_dp = forward_context.max_tokens_across_dp
@@ -323,7 +373,6 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
             else:
                 router_logits = self.moe_config.dp_group.all_gather(
                     router_logits, 0)

         return hidden_states, router_logits, None, None

     def finalize(self,

@@ -331,6 +380,36 @@ class FusedMoEPrepareAndFinalizeWithAllGather(FusedMoEPrepareAndFinalize):
                  reduce_results: bool,
                  context_metadata: Optional[dict] = None) -> torch.Tensor:
+        """
+        Finalization steps:
+        Reduce-scatter hidden states.
+
+        Returns:
+            Tensor with shape [local_num_tokens, hidden_size]
+        """
+        if enable_sp():
+            return self._finalize_with_ep_group(hidden_states)
+
+        return self._finalize_with_dp_group(hidden_states, reduce_results)
+
+    def _finalize_with_ep_group(self,
+                                hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        The `reduce_results` argument is not needed here, given that sequence parallelism is enabled:
+        1. reduce_results is False usually when the model has shared experts and the hidden states need
+           an all-reduce after the shared-expert and routed-expert results are added in FusedMoE.
+           We reduce-scatter the hidden states here, skip the all-reduce in FusedMoE, and add the result
+           to the shared experts' output.
+        2. reduce_results is True usually when the model has no shared experts. We still reduce-scatter
+           here and skip the all-reduce in FusedMoE.
+        """
+        hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
+            hidden_states, True)
+
+        return hidden_states
+
+    def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
+                                reduce_results: bool) -> torch.Tensor:
         """
         Finalization steps:
         1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
         2. Slice to original local token count.

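A toy check (plain tensors, assumed sizes) of case 1 in the `_finalize_with_ep_group` docstring: reduce-scattering the routed experts' partial sums and adding the local shared-expert shard matches the all-reduce-then-slice it replaces.

import torch

tp, tokens, hidden = 2, 4, 8
routed_partials = [torch.randn(tokens, hidden) for _ in range(tp)]  # per-TP-rank partial sums
shared_out = torch.randn(tokens, hidden)                            # shared experts, replicated

reduced = sum(routed_partials)                                      # what all-reduce would give
for rank in range(tp):
    shard = slice(rank * tokens // tp, (rank + 1) * tokens // tp)
    via_all_reduce = (reduced + shared_out)[shard]
    via_reduce_scatter = reduced[shard] + shared_out[shard]         # RS output + local shared shard
    assert torch.allclose(via_all_reduce, via_reduce_scatter)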
@@ -1,7 +1,9 @@
 import torch
 import torch.nn.functional as F
 import torch_npu
-from vllm.distributed import (tensor_model_parallel_all_gather,
+from vllm.distributed import (get_dp_group, get_ep_group,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce,
                               tensor_model_parallel_reduce_scatter)
 from vllm.forward_context import get_forward_context
@@ -13,8 +15,10 @@ from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
 from vllm_ascend.utils import npu_stream_switch, prefetch_stream


-def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,
-                                           label: bool) -> torch.Tensor:
+def _maybe_all_gather_and_maybe_unpad_impl(
+        x: torch.Tensor,
+        label: bool,
+        is_ep_comm: bool = False) -> torch.Tensor:
     try:
         forward_context = get_forward_context()
     except AssertionError:

@@ -22,27 +26,66 @@ def _maybe_all_gather_and_maybe_unpad_impl(x: torch.Tensor,

     sp_enabled = forward_context.sp_enabled
     if sp_enabled and label:
-        x = tensor_model_parallel_all_gather(x, 0)
-        pad_size = forward_context.pad_size
-        if pad_size > 0:
-            x = x[:-pad_size, :]
+        dp_metadata = forward_context.dp_metadata
+        if dp_metadata is None or not is_ep_comm:
+            x = tensor_model_parallel_all_gather(x, 0)
+            pad_size = forward_context.pad_size
+            if pad_size > 0:
+                x = x[:-pad_size, :]
+        else:
+            x = get_ep_group().all_gather(x, 0)
+            # unpad
+            num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
+            result = torch.empty(
+                (num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
+                device=x.device,
+                dtype=x.dtype)
+            dp_size = get_dp_group().world_size
+            x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
+            offset = 0
+            for idx in range(dp_size):
+                num_tokens_dp = num_tokens_across_dp_cpu[idx]
+                result[offset:offset +
+                       num_tokens_dp, :] = x[idx, :num_tokens_dp, :]
+                offset += num_tokens_dp
+            x = result

     return x


-def _maybe_pad_and_reduce_impl(x: torch.Tensor) -> torch.Tensor:
+def _maybe_pad_and_reduce_impl(x: torch.Tensor,
+                               is_ep_comm: bool = False) -> torch.Tensor:
     try:
         forward_context = get_forward_context()
     except AssertionError:
         return tensor_model_parallel_all_reduce(x)

-    sp_enabled = forward_context.sp_enabled
-    if sp_enabled:
+    if not forward_context.sp_enabled:
+        return tensor_model_parallel_all_reduce(x)
+
+    dp_metadata = forward_context.dp_metadata
+    if dp_metadata is None or not is_ep_comm:
         pad_size = forward_context.pad_size
         if pad_size > 0:
             x = F.pad(x, (0, 0, 0, pad_size))
         return tensor_model_parallel_reduce_scatter(x, 0)
-    else:
-        return tensor_model_parallel_all_reduce(x)
+    # padding
+    dp_size = get_dp_group().world_size
+    num_tokens_across_dp_cpu = \
+        get_forward_context().dp_metadata.num_tokens_across_dp_cpu
+    padded_x = torch.empty(
+        (dp_size, forward_context.padded_length, *x.shape[1:]),
+        device=x.device,
+        dtype=x.dtype)
+    offset = 0
+    for idx in range(dp_size):
+        num_tokens_dp = num_tokens_across_dp_cpu[idx]
+        padded_x[idx, :num_tokens_dp] = x[offset:offset + num_tokens_dp]
+        offset += num_tokens_dp
+
+    return get_ep_group().reduce_scatter(padded_x.view(-1, *x.shape[1:]),
+                                         0)


 def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,

@@ -71,6 +114,33 @@ def _maybe_prefetch_mlp_gate_up_proj_impl(x_dependency: torch.Tensor,
     return


+def _maybe_all_gather_and_maybe_unpad_fake(
+        x: torch.Tensor,
+        label: bool,
+        is_ep_comm: bool = False) -> torch.Tensor:
+
+    if get_forward_context().sp_enabled and label:
+        return torch.empty(
+            (x.shape[0] * get_tensor_model_parallel_world_size(),
+             *x.shape[1:]),
+            device=x.device,
+            dtype=x.dtype)
+
+    return x
+
+
+def _maybe_pad_and_reduce_fake(x: torch.Tensor,
+                               is_ep_comm: bool = False) -> torch.Tensor:
+    if get_forward_context().sp_enabled:
+        return torch.empty(
+            (x.shape[0] // get_tensor_model_parallel_world_size(),
+             *x.shape[1:]),
+            device=x.device,
+            dtype=x.dtype)
+
+    return x
+
+
 def _maybe_prefetch_mlp_gate_up_proj_impl_fake(x_dependency: torch.Tensor,
                                                prefix: str) -> None:
     return

@@ -158,7 +228,8 @@ def _maybe_all_reduce_tensor_model_parallel_impl(
         final_hidden_states: torch.Tensor) -> torch.Tensor:
     forward_context = get_forward_context()
     moe_comm_type = forward_context.moe_comm_type
-    if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
+    if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2
+                         } or forward_context.sp_enabled:
         return final_hidden_states
     else:
         return tensor_model_parallel_all_reduce(final_hidden_states)
@@ -166,13 +237,13 @@ def _maybe_all_reduce_tensor_model_parallel_impl(

 direct_register_custom_op(op_name="maybe_all_gather_and_maybe_unpad",
                           op_func=_maybe_all_gather_and_maybe_unpad_impl,
-                          fake_impl=lambda x, label: x,
+                          fake_impl=_maybe_all_gather_and_maybe_unpad_fake,
                           mutates_args=[],
                           dispatch_key="PrivateUse1")

 direct_register_custom_op(op_name="maybe_pad_and_reduce",
                           op_func=_maybe_pad_and_reduce_impl,
-                          fake_impl=lambda x: x,
+                          fake_impl=_maybe_pad_and_reduce_fake,
                           mutates_args=[],
                           dispatch_key="PrivateUse1")

@@ -31,7 +31,7 @@ from vllm_ascend.ascend_config import (check_ascend_config, get_ascend_config,
                                        init_ascend_config)
 from vllm_ascend.torchair.utils import (check_torchair_cache_exist,
                                         delete_torchair_cache_file)
-from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, is_310p,
+from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, enable_sp, is_310p,
                                update_aclgraph_sizes)

 if TYPE_CHECKING:

@@ -211,6 +211,21 @@ class NPUPlatform(Platform):

         # set cudagraph sizes before extending `compilation_config.splitting_ops`
         vllm_config._set_cudagraph_sizes()
+        # TODO delete graph size update here when compilation_config.pass_config.enable_sequence_parallelism
+        # is supported by vllm-ascend.
+        if vllm_config.parallel_config.tensor_parallel_size > 1 and not vllm_config.model_config.enforce_eager and \
+                enable_sp(vllm_config):
+            original_sizes = compilation_config.cudagraph_capture_sizes
+            sp_aclgraph_sizes = \
+                vllm_config.update_sizes_for_sequence_parallelism(original_sizes)
+            assert sp_aclgraph_sizes, (
+                f"cudagraph_capture_sizes {original_sizes} does not contain "
+                f"values that are multiples of tp_size "
+                f"{vllm_config.parallel_config.tensor_parallel_size}")
+            if len(sp_aclgraph_sizes) != len(original_sizes):
+                compilation_config.cudagraph_capture_sizes = sp_aclgraph_sizes
+                vllm_config.compilation_config.init_with_cudagraph_sizes(
+                    sp_aclgraph_sizes)

         # TODO: Full graph will be fully supported later, and the default value will be set to full graph.
         if compilation_config.cudagraph_mode == CUDAGraphMode.FULL_AND_PIECEWISE:

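A sketch of the filtering the block above relies on (an assumption about `update_sizes_for_sequence_parallelism`, not its actual implementation): with sequence parallelism, captured graph sizes must stay divisible by tp_size.

def keep_sp_compatible_sizes(sizes, tp_size):
    # Hypothetical stand-in: drop capture sizes that cannot be split evenly across TP ranks.
    return [s for s in sizes if s % tp_size == 0]

original_sizes = [1, 2, 4, 8, 16, 24, 32]
assert keep_sp_compatible_sizes(original_sizes, tp_size=4) == [4, 8, 16, 24, 32]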
@@ -55,6 +55,7 @@ _PREFETCH_STREAM = None
 _ASCEND_CUSTOMOP_IS_REIGISTERED = False
 _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
+_IS_MOE_MODEL = None


 def is_310p():

@@ -609,12 +610,24 @@ def enable_sp(vllm_config=None) -> bool:
         vllm_config = get_current_vllm_config()
     return (
         vllm_config.compilation_config.pass_config.enable_sequence_parallelism
-        or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM)
+        or envs_ascend.VLLM_ASCEND_ENABLE_FLASHCOMM1
+        # Flash comm 1 should be enabled by the env VLLM_ASCEND_ENABLE_FLASHCOMM1.
+        # We retain the env VLLM_ASCEND_ENABLE_FLASHCOMM here for backward compatibility.
+        or bool(int(os.getenv("VLLM_ASCEND_ENABLE_FLASHCOMM", '0'))))


+# TODO remove it after vllm has this func
+def shared_expert_dp_enabled() -> bool:
+    return get_ascend_config().enable_shared_expert_dp or enable_sp()
+
+
 def is_moe_model(vllm_config: VllmConfig):
-    config = vllm_config.model_config.hf_config
-    return any('experts' in key.lower() for key in config.to_dict())
+    global _IS_MOE_MODEL
+    if _IS_MOE_MODEL is None:
+        config = vllm_config.model_config.hf_config
+        _IS_MOE_MODEL = any('experts' in key.lower()
+                            for key in config.to_dict())
+    return _IS_MOE_MODEL


 def weak_ref_tensor(tensor: Any) -> Any:

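The MoE check above is a key-name heuristic over the HF config; a tiny sketch with a hypothetical config dict:

hf_config = {"num_hidden_layers": 48, "num_experts": 128, "hidden_size": 2048}
assert any('experts' in key.lower() for key in hf_config)
# The result is memoized in the module-level _IS_MOE_MODEL, so the config is
# only inspected on the first call to is_moe_model().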
@@ -20,6 +20,7 @@
 import copy
 import gc
 import itertools
+import math
 import re
 import time
 from collections import defaultdict
@@ -128,8 +129,8 @@ from vllm_ascend.spec_decode.interface import SpecDcodeType
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendSocVersion, ProfileExecuteDuration,
-                               get_ascend_soc_version, is_310p, is_enable_nz,
-                               lmhead_tp_enable)
+                               enable_sp, get_ascend_soc_version, is_310p,
+                               is_enable_nz, lmhead_tp_enable)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

 if TYPE_CHECKING:

@@ -1210,6 +1211,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             # Add padding to the batch size.
             num_input_tokens = self.vllm_config.pad_for_cudagraph(
                 total_num_scheduled_tokens)
+        elif self.use_aclgraph and enable_sp(self.vllm_config):
+            # When using aclgraph, if total_num_scheduled_tokens exceeds the maximum graph size,
+            # the model will fall back to running its FX graph in eager mode.
+            # In this case, when sequence parallelism is enabled, we need to pad tokens to align
+            # with tp_size because pad_size cannot be captured by the FX graph
+            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+            num_input_tokens = math.ceil(
+                total_num_scheduled_tokens / tp_size) * tp_size
         else:
             # Eager mode.
             num_input_tokens = total_num_scheduled_tokens

@@ -1850,7 +1859,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             raise ValueError(f"Unsupported soc_version: {soc_version}")

         if moe_comm_type == MoECommType.ALLGATHER and with_prefill:
-            moe_comm_type = MoECommType.NAIVE_MULTICAST
+            if enable_sp():
+                moe_comm_type = MoECommType.ALLGATHER
+            else:
+                moe_comm_type = MoECommType.NAIVE_MULTICAST

         # PanguProMoE only supports allgather
         if model_type == "PanguProMoE":

@@ -2314,6 +2326,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL
         }

+        # In multi-DP scenarios, there may be situations where all DP groups are executing dummy runs.
+        # If sequence parallelism is enabled, it is essential to ensure that num_tokens is divisible by tp_size.
+        if self.use_aclgraph and enable_sp(self.vllm_config):
+            tp_size = self.vllm_config.parallel_config.tensor_parallel_size
+            num_tokens = math.ceil(num_tokens / tp_size) * tp_size
+
         # Padding for DP
         (num_tokens, num_tokens_across_dp, with_prefill,
          _) = self._sync_metadata_across_dp(num_tokens, with_prefill, False)