Files
vllm-ascend/vllm_ascend/ops/sequence_parallel.py
lbk-sys c611291661 【main】SP For Qwen3 MoE (#2209)
### What this PR does / why we need it?
Qwen3 MoE now supports sequence parallelism (SP). In the AlltoAll, AlltoAllv,
and MC2 scenarios, replacing AllReduce with Reduce-Scatter followed by
AllGather reduces the computation spent in norm operations while saving one
AllGather communication. The feature is enabled only during the prefill (P)
phase and delivers notable gains in long-sequence scenarios (e.g., 16k–25k
tokens), with performance improvements of 5%–10%.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
``` 
compilation_config={
    "pass_config":{
        "enable_sequence_parallelism": True
    }
},
enable_expert_parallel=True,
```

- vLLM version: v0.10.0
- vLLM main: 9edd1db02b

---------

Signed-off-by: libaokui <libaokui@huawei.com>
Co-authored-by: libaokui <libaokui@huawei.com>
2025-08-07 09:15:49 +08:00

121 lines
4.4 KiB
Python

import torch
from torch.nn import functional as F
from vllm.distributed import (get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context
from vllm_ascend.platform import NPUPlatform
class MetadataForPadding:
    """Padding metadata and collective helpers for sequence parallelism (SP).

    The token dimension (dim 0) is padded up to ``lengths_sum_padding``, which
    must be divisible by the tensor-parallel (TP) world size, so that it can be
    split evenly across TP ranks. Each rank then owns ``slice_size``
    consecutive rows starting at ``tp_rank_in_group * slice_size``.

    Args:
        padding_flag: Whether dim 0 must be padded before it is scattered or
            reduce-scattered (and un-padded after it is all-gathered).
        lengths_sum_padding: Padded total token count. Must be divisible by
            the TP world size.
        lengths_sum_unpadding: Real (unpadded) total token count.
        pad_size: Number of zero rows appended when padding.
        not_dummy_and_is_prefill: Stored for callers; this class does not
            read it.
    """

    def __init__(self,
                 padding_flag=False,
                 lengths_sum_padding=0,
                 lengths_sum_unpadding=0,
                 pad_size=0,
                 not_dummy_and_is_prefill=False):
        self.padding_flag = padding_flag
        self.not_dummy_and_is_prefill = not_dummy_and_is_prefill
        self.lengths_sum_padding = lengths_sum_padding
        self.lengths_sum_unpadding = lengths_sum_unpadding
        self.pad_size = pad_size

        # Look the TP group up once and cache its size / this rank's index.
        tp_group = get_tp_group()
        self.tp_size = tp_group.world_size
        self.tp_rank_in_group = tp_group.rank_in_group

        assert self.lengths_sum_padding % self.tp_size == 0, (
            f"lengths_sum_padding ({self.lengths_sum_padding}) must be "
            f"divisible by the TP world size ({self.tp_size})")
        # Rows owned by each rank after an even split of the padded length.
        self.slice_size = self.lengths_sum_padding // self.tp_size

        # Boolean mask over the padded length: True for the first
        # ``lengths_sum_unpadding`` (real) rows, False for padding rows.
        # NOTE(review): presumably consumed by the MC2 MoE path — confirm.
        self.mc2_mask = torch.zeros(
            self.lengths_sum_padding,
            dtype=torch.bool,
            device=NPUPlatform.device_type,
        )
        self.mc2_mask[:lengths_sum_unpadding] = True

    @staticmethod
    def _pad_rows(data: torch.Tensor, pad_size: int) -> torch.Tensor:
        """Return a copy of ``data`` with ``pad_size`` zero rows appended on dim 0."""
        # F.pad takes (left, right) pairs starting from the LAST dimension, so
        # leave every trailing dimension untouched and pad only dim 0. For 2-D
        # input this is exactly (0, 0, 0, pad_size).
        pad = (0, 0) * (data.dim() - 1) + (0, pad_size)
        return F.pad(data, pad)

    def padding_aligned_reduce_scatter(self,
                                       data: torch.Tensor) -> torch.Tensor:
        """Pad dim 0 if needed, then reduce-scatter it across the TP group.

        Returns this rank's reduced shard of the (padded) input.
        """
        if self.padding_flag:
            data = self._pad_rows(data, self.pad_size)
        return tensor_model_parallel_reduce_scatter(data, 0)

    def allgather_unpadding_aligned(self,
                                    padded_data: torch.Tensor) -> torch.Tensor:
        """All-gather dim 0 across the TP group, then drop the padding rows.

        When ``padding_flag`` is set, only the first ``lengths_sum_unpadding``
        rows of the gathered tensor are returned.
        """
        gathered = tensor_model_parallel_all_gather(padded_data, 0)
        if self.padding_flag:
            return gathered[:self.lengths_sum_unpadding]
        return gathered

    def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
        """Pad dim 0 by ``pad_size`` rows and return this rank's row slice.

        Padding is applied unconditionally (a zero ``pad_size`` still yields a
        padded copy). The slice is rows
        ``[tp_rank_in_group * slice_size, (tp_rank_in_group + 1) * slice_size)``.
        """
        padded = self._pad_rows(data, self.pad_size)
        start = self.tp_rank_in_group * self.slice_size
        return padded[start:start + self.slice_size]

    def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
        """Pad dim 0 if needed and return this rank's chunk (no communication).

        The (padded) tensor is split into ``tp_size`` chunks along dim 0 with
        ``torch.tensor_split``. The chunk at index ``tp_rank_in_group`` is
        returned.
        """
        if self.padding_flag:
            data = self._pad_rows(data, self.pad_size)
        chunks = torch.tensor_split(data, self.tp_size, dim=0)
        return chunks[self.tp_rank_in_group]
def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
    """Build the sequence-parallel padding metadata for the current step.

    Sequence parallelism is applied only when it is enabled and the forward
    context's attention metadata indicates prefill work. In that case the
    token count ``input_ids.shape[0]`` is rounded up to the next multiple of
    the tensor-parallel world size.

    Args:
        input_ids: Token ids for this step. Only ``shape[0]`` (the token
            count) is read.
        enable_sequence_parallelism: Whether sequence parallelism is enabled.

    Returns:
        A ``MetadataForPadding``. When SP does not apply, it is built with
        ``padding_flag=False`` and ``not_dummy_and_is_prefill=False``.
    """
    if not enable_sequence_parallelism:
        return MetadataForPadding(padding_flag=False,
                                  not_dummy_and_is_prefill=False)

    attn_metadata = get_forward_context().attn_metadata
    # Either flag marks a step containing prefill work. Metadata objects that
    # lack these attributes (or a missing attn_metadata) take the non-SP path.
    is_prefill = attn_metadata is not None and (
        getattr(attn_metadata, 'is_only_prefill', False)
        or getattr(attn_metadata, 'num_prefills', 0) > 0)
    if not is_prefill:
        return MetadataForPadding(padding_flag=False,
                                  not_dummy_and_is_prefill=False)

    tp_size = get_tensor_model_parallel_world_size()
    lengths_sum_unpadding = input_ids.shape[0]
    # Round up so dim 0 divides evenly across the TP ranks.
    lengths_sum_padding = (
        (lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
    pad_size = lengths_sum_padding - lengths_sum_unpadding
    return MetadataForPadding(
        lengths_sum_unpadding=lengths_sum_unpadding,
        lengths_sum_padding=lengths_sum_padding,
        padding_flag=pad_size > 0,
        pad_size=pad_size,
        not_dummy_and_is_prefill=True)