Mirror of https://github.com/vllm-project/vllm-ascend.git (synced 2025-10-20 21:53:54 +08:00)
### What this PR does / why we need it?
This PR adds sequence parallelism (SP) support for Qwen3 MoE. In the AlltoAll, AlltoAllv, and MC2 communication scenarios, replacing AllReduce with Reduce-Scatter plus AllGather lets the norm operations run on a per-rank token shard (a computational saving) and additionally saves one AllGather communication. The feature is enabled during the prefill (P) phase and delivers notable gains on long sequences (e.g., 16k–25k tokens), with performance improvements of 5%–10%.
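To make the substitution concrete, here is a minimal, schematic sketch of the pattern using plain `torch.distributed` collectives; it is not the code from this PR, and it assumes the token count is already padded to a multiple of the TP world size:

```python
# Schematic sketch (not this PR's code) of the communication substitution,
# assuming the token count is already padded to a multiple of the TP size.
import torch
import torch.distributed as dist


def allreduce_then_norm(x: torch.Tensor, norm: torch.nn.Module) -> torch.Tensor:
    # Baseline: every rank reduces and then normalizes the full
    # [num_tokens, hidden] tensor.
    dist.all_reduce(x)
    return norm(x)


def reduce_scatter_norm_allgather(x: torch.Tensor,
                                  norm: torch.nn.Module) -> torch.Tensor:
    # SP variant: each rank receives only its token shard of the reduced
    # tensor, so the norm touches 1/tp_size of the tokens.
    tp_size = dist.get_world_size()
    shard = x.new_empty(x.shape[0] // tp_size, x.shape[1])
    dist.reduce_scatter_tensor(shard, x)
    shard = norm(shard)
    # Gather the full sequence back only where the next op needs it.
    out = torch.empty_like(x)
    dist.all_gather_into_tensor(out, shard)
    return out
```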
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
The feature was exercised with sequence parallelism and expert parallelism enabled via the following configuration:
```
compilation_config={
    "pass_config": {
        "enable_sequence_parallelism": True
    }
},
enable_expert_parallel=True,
```
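For reference, a hedged sketch of how these options might be passed to vLLM's offline `LLM` entry point; the model name and TP size here are illustrative assumptions, not values taken from this PR:

```python
from vllm import LLM, SamplingParams

# Illustrative settings: the checkpoint and parallel sizes are assumptions.
llm = LLM(
    model="Qwen/Qwen3-30B-A3B",   # any Qwen3 MoE checkpoint (assumed)
    tensor_parallel_size=4,        # assumed TP degree
    enable_expert_parallel=True,
    compilation_config={
        "pass_config": {
            "enable_sequence_parallelism": True
        }
    },
)
outputs = llm.generate(["A long prompt ..."], SamplingParams(max_tokens=64))
```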
- vLLM version: v0.10.0
- vLLM main: 9edd1db02b
---------
Signed-off-by: libaokui <libaokui@huawei.com>
Co-authored-by: libaokui <libaokui@huawei.com>
121 lines · 4.4 KiB · Python
import torch
from torch.nn import functional as F
from vllm.distributed import (get_tensor_model_parallel_world_size,
                              get_tp_group, tensor_model_parallel_all_gather,
                              tensor_model_parallel_reduce_scatter)
from vllm.forward_context import get_forward_context

from vllm_ascend.platform import NPUPlatform


class MetadataForPadding:
    """Padding and sharding metadata for sequence parallelism over the TP group."""

    def __init__(self,
                 padding_flag=False,
                 lengths_sum_padding=0,
                 lengths_sum_unpadding=0,
                 pad_size=0,
                 not_dummy_and_is_prefill=False):
        self.padding_flag = padding_flag
        self.not_dummy_and_is_prefill = not_dummy_and_is_prefill

        self.lengths_sum_padding = lengths_sum_padding
        self.lengths_sum_unpadding = lengths_sum_unpadding
        self.pad_size = pad_size

        self.tp_size = get_tp_group().world_size
        self.tp_rank_in_group = get_tp_group().rank_in_group

        # The padded token count must split evenly across the TP ranks.
        assert self.lengths_sum_padding % self.tp_size == 0
        self.slice_size = self.lengths_sum_padding // self.tp_size

        # Mask of real (non-padding) tokens over the padded length.
        self.mc2_mask = torch.zeros(
            self.lengths_sum_padding,
            dtype=torch.bool,
            device=NPUPlatform.device_type,
        )
        self.mc2_mask[:lengths_sum_unpadding] = True

    def padding_aligned_reduce_scatter(self,
                                       data: torch.Tensor) -> torch.Tensor:
        # Pad the token dimension to the aligned length, then reduce-scatter
        # so each TP rank keeps only its token shard of the reduced tensor.
        if self.padding_flag:
            pad_size = self.pad_size
            padded_data = F.pad(data, (0, 0, 0, pad_size))
        else:
            padded_data = data
        padded_data_reduce_scatter = tensor_model_parallel_reduce_scatter(
            padded_data, 0)

        return padded_data_reduce_scatter

    def allgather_unpadding_aligned(self,
                                    padded_data: torch.Tensor) -> torch.Tensor:
        # Gather the token shards from all TP ranks and drop the padding tokens.
        padded_data_allgather = tensor_model_parallel_all_gather(
            padded_data, 0)
        if self.padding_flag:
            lengths_sum_unpadding = self.lengths_sum_unpadding
            unpadding_data = padded_data_allgather[:lengths_sum_unpadding]
        else:
            unpadding_data = padded_data_allgather
        return unpadding_data

    def padding_slice(self, data: torch.Tensor) -> torch.Tensor:
        # Pad, then take this rank's contiguous token slice (no communication).
        padded_data = F.pad(data, (0, 0, 0, self.pad_size))
        start = self.tp_rank_in_group * self.slice_size
        end = start + self.slice_size
        slice_data = padded_data[start:end]

        return slice_data

    def padding_aligned_scatter(self, data: torch.Tensor) -> torch.Tensor:
        # Pad, split into tp_size equal chunks, and keep this rank's chunk.
        if self.padding_flag:
            pad_size = self.pad_size
            padded_data = F.pad(data, (0, 0, 0, pad_size))
        else:
            padded_data = data
        chunks = torch.tensor_split(padded_data, self.tp_size, dim=0)

        return chunks[self.tp_rank_in_group]


def init_metadata_for_sp(input_ids, enable_sequence_parallelism):
    if not enable_sequence_parallelism:
        return MetadataForPadding(padding_flag=False,
                                  not_dummy_and_is_prefill=False)

    is_prefill = 0
    attn_metadata = get_forward_context().attn_metadata
    tp_size = get_tensor_model_parallel_world_size()
    if attn_metadata is not None:
        if hasattr(attn_metadata,
                   'is_only_prefill') and attn_metadata.is_only_prefill:
            is_prefill = 1
        if hasattr(attn_metadata,
                   'num_prefills') and attn_metadata.num_prefills > 0:
            is_prefill = 1

        if is_prefill:
            # Round the token count up to a multiple of tp_size so the
            # sequence can be sharded evenly across the TP group.
            lengths_sum_unpadding = input_ids.shape[0]
            lengths_sum_padding = (
                (lengths_sum_unpadding + tp_size - 1) // tp_size) * tp_size
            padding_flag = lengths_sum_unpadding != lengths_sum_padding
            pad_size = lengths_sum_padding - lengths_sum_unpadding
            return MetadataForPadding(
                lengths_sum_unpadding=lengths_sum_unpadding,
                lengths_sum_padding=lengths_sum_padding,
                padding_flag=padding_flag,
                pad_size=pad_size,
                not_dummy_and_is_prefill=True)

    return MetadataForPadding(padding_flag=False,
                              not_dummy_and_is_prefill=False)
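For context, a minimal sketch of how this metadata might be wired around a layer's post-attention path; the `norm` and `mlp` callables are stand-ins (assumptions), not code from this PR:

```python
# Hypothetical wiring sketch (not from this PR): shard tokens before the
# norm/MLP and restore the full, unpadded sequence afterwards.
def forward_with_sp(hidden_states, residual, norm, mlp, metadata):
    if metadata.not_dummy_and_is_prefill:
        # Replace the usual AllReduce: each rank keeps 1/tp_size of the
        # tokens, so the norm and MLP below operate on a shard.
        hidden_states = metadata.padding_aligned_reduce_scatter(hidden_states)
        residual = metadata.padding_slice(residual)

    hidden_states, residual = norm(hidden_states, residual)
    hidden_states = mlp(hidden_states)

    if metadata.not_dummy_and_is_prefill:
        # Restore the full sequence for the next stage.
        hidden_states = metadata.allgather_unpadding_aligned(hidden_states)
        residual = metadata.allgather_unpadding_aligned(residual)
    return hidden_states, residual
```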