Mirror of https://github.com/vllm-project/vllm-ascend.git
[bugfix][torchair] fix wasted NPU memory buffer allocation for quantized deepseek with unquantized MTP layer (#3068)
### What this PR does / why we need it?
While running quantized DeepSeek models with an unquantized MTP layer, free
NPU memory abnormally decreases by `2*HCCL_BUFFSIZE` bytes. This
results from a wasted VRAM buffer allocation caused by calling
`dist.all_to_all_single` without the correct device process group argument.
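A minimal sketch of the fix pattern (the helper name and shapes below are illustrative, not the repo's code): routing the collective through the expert-parallel `device_group` reuses the HCCL communicator that already exists for that group, instead of lazily initializing an extra buffer for the default (WORLD) process group.

```python
# Illustrative sketch only: before/after call pattern for dist.all_to_all_single.
# `ep_group` stands in for vLLM's GroupCoordinator, whose `.device_group` is the
# HCCL-backed process group on NPU (as used in the diff below).
import torch
import torch.distributed as dist


def exchange_expert_token_counts(global_expert_tokens: torch.Tensor, ep_group):
    gather_sizes = global_expert_tokens.new_empty(global_expert_tokens.shape[0])

    # Before: no `group=` argument, so the collective runs on the default
    # process group and HCCL lazily allocates an extra communication buffer
    # (~2 * HCCL_BUFFSIZE) for it on first use.
    # dist.all_to_all_single(gather_sizes, global_expert_tokens)

    # After: reuse the expert-parallel device group's existing communicator.
    dist.all_to_all_single(gather_sizes,
                           global_expert_tokens,
                           group=ep_group.device_group)
    return gather_sizes
```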
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
We ran vLLM online serving with quantized DeepSeek-R1 and an unquantized
MTP layer, and observed that free memory increased, with no redundant VRAM
buffer allocated for the HCCL communication op (`all_to_all_single`).
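A hypothetical way to reproduce the observation (assuming `torch.npu.mem_get_info` mirrors the CUDA-style API in torch_npu, and that `HCCL_BUFFSIZE` is expressed in MB with a 200 MB default): compare free device memory around the first MoE forward pass and check that the delta no longer includes an extra communicator buffer.

```python
# Hypothetical measurement sketch, not the actual test procedure.
# Assumes torch_npu exposes torch.npu.mem_get_info() (CUDA-style API) and
# that HCCL_BUFFSIZE is given in MB (assumed default: 200).
import os

import torch
import torch_npu  # noqa: F401  registers the NPU device backend

free_before, _total = torch.npu.mem_get_info()

# ... trigger the first MoE forward pass that calls dist.all_to_all_single ...

free_after, _total = torch.npu.mem_get_info()
delta_mb = (free_before - free_after) / 2**20
hccl_buffsize_mb = int(os.environ.get("HCCL_BUFFSIZE", "200"))
print(f"free memory drop: {delta_mb:.1f} MB; a redundant default-group "
      f"communicator would add roughly {2 * hccl_buffsize_mb} MB on top")
```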
- vLLM version: v0.10.2
- vLLM main: 6d8246aaff
Signed-off-by: linfeng-yuan <1102311262@qq.com>
@@ -416,6 +416,7 @@ def torchair_fused_experts_with_all2all(
     num_experts = w1.shape[0]
 
     if expert_map is not None:
+        assert ep_group is not None, "ep_group must be provided when expert_map is given"
         global_num_experts = len(expert_map) + global_redundant_expert_num
         if hasattr(torch_npu, "npu_moe_init_routing_quant"):
             quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant(
@@ -435,8 +436,9 @@ def torchair_fused_experts_with_all2all(
 
         gather_sizes = global_expert_tokens.new_empty(
             global_expert_tokens.shape[0])
-        dist.all_to_all_single(gather_sizes, global_expert_tokens)
-
+        dist.all_to_all_single(gather_sizes,
+                               global_expert_tokens,
+                               group=ep_group.device_group)
         token_counts_combined = torch.stack(
             [gather_sizes, global_expert_tokens], dim=0)
         token_counts_combined = token_counts_combined.view(
@@ -451,10 +453,16 @@ def torchair_fused_experts_with_all2all(
         gather_size_list = token_counts_combined_cpu[1]
         scatter_size_list = token_counts_combined_cpu[0]
 
-        dist.all_to_all_single(gathered_tokens, quantized_tokens,
-                               scatter_size_list, gather_size_list)
-        dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list,
-                               gather_size_list)
+        dist.all_to_all_single(gathered_tokens,
+                               quantized_tokens,
+                               scatter_size_list,
+                               gather_size_list,
+                               group=ep_group.device_group)
+        dist.all_to_all_single(dynamic_scale,
+                               token_scales,
+                               scatter_size_list,
+                               gather_size_list,
+                               group=ep_group.device_group)
 
         hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing(
             gathered_tokens,
@@ -502,9 +510,11 @@ def torchair_fused_experts_with_all2all(
             index=inverse_indices.to(torch.float32).argsort().to(torch.int32))
 
         hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape)
-        dist.all_to_all_single(hidden_states, reordered_outputs,
-                               gather_size_list, scatter_size_list)
-
+        dist.all_to_all_single(hidden_states,
+                               reordered_outputs,
+                               gather_size_list,
+                               scatter_size_list,
+                               group=ep_group.device_group)
         final_hidden_states = torch_npu.npu_moe_finalize_routing(
             hidden_states,
             skip1=None,