[bugfix][torchair] fix wasted NPU memory buffer allocation for quantized deepseek with unquantized MTP layer (#3068)

### What this PR does / why we need it?
While running quantized deepseek models with unquantized MTP layer, free
NPU memory abnormally decreases for `2*HCCL_BUFFSIZE` bytes. This
results from the wasted VRAM buffer allocation casued by calling
`dist.all_to_all_single` without correct device process group argument.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
We run vllm online serving with quantized deepseek-r1 and unquantized
MTP layer, and observed that free_memory increased without redundat VRAM
buffer for HCCL communication op (all_to_all_single).

- vLLM version: v0.10.2
- vLLM main:
6d8246aaff

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2025-09-22 14:06:43 +08:00
committed by GitHub
parent 14b39d3c70
commit ffdd1a36e2

View File

@ -416,6 +416,7 @@ def torchair_fused_experts_with_all2all(
num_experts = w1.shape[0]
if expert_map is not None:
assert ep_group is not None, "ep_group must be provided when expert_map is given"
global_num_experts = len(expert_map) + global_redundant_expert_num
if hasattr(torch_npu, "npu_moe_init_routing_quant"):
quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
@ -435,8 +436,9 @@ def torchair_fused_experts_with_all2all(
gather_sizes = global_expert_tokens.new_empty(
global_expert_tokens.shape[0])
dist.all_to_all_single(gather_sizes, global_expert_tokens)
dist.all_to_all_single(gather_sizes,
global_expert_tokens,
group=ep_group.device_group)
token_counts_combined = torch.stack(
[gather_sizes, global_expert_tokens], dim=0)
token_counts_combined = token_counts_combined.view(
@ -451,10 +453,16 @@ def torchair_fused_experts_with_all2all(
gather_size_list = token_counts_combined_cpu[1]
scatter_size_list = token_counts_combined_cpu[0]
dist.all_to_all_single(gathered_tokens, quantized_tokens,
scatter_size_list, gather_size_list)
dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
gather_size_list)
dist.all_to_all_single(gathered_tokens,
quantized_tokens,
scatter_size_list,
gather_size_list,
group=ep_group.device_group)
dist.all_to_all_single(dynamic_scale,
token_scales,
scatter_size_list,
gather_size_list,
group=ep_group.device_group)
hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
gathered_tokens,
@ -502,9 +510,11 @@ def torchair_fused_experts_with_all2all(
index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
dist.all_to_all_single(hidden_states, reordered_outputs,
gather_size_list, scatter_size_list)
dist.all_to_all_single(hidden_states,
reordered_outputs,
gather_size_list,
scatter_size_list,
group=ep_group.device_group)
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,