fix: modify the topkgating function and update the test_moe tests accordingly (#7163)

The previous PR ran into a DCO problem that could not be resolved, so this resubmits the identical change without that issue.

---------

Signed-off-by: xiongjyu <xiongjyu@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <tjruwase@gmail.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Author: xiongjyu
Date: 2025-06-07 07:42:41 +08:00
Committed by: GitHub
Parent: 24a1d8f936
Commit: 770967f5f0
2 changed files with 10 additions and 0 deletions


@@ -429,6 +429,7 @@ def topkgating(
tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
capacity = new_capacity
locations = torch.cumsum(mask, dim=0) - 1
# normalize gates
gates_masked = gates * mask
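
For context, this hunk sets the capacity from the actual per-expert token counts when tokens are not dropped, rounds it up to a multiple of the tensor-parallel world size, and, in what appears to be the newly added line, recomputes the per-expert slot of every routed token so dispatch positions stay consistent with the new capacity. A minimal standalone sketch of those two steps, with an illustrative mask and tp value rather than the DeepSpeed internals:

import torch

# mask[s, e] == 1 if token s was routed to expert e by the top-k selection.
mask = torch.tensor([[1, 0, 1, 0],
                     [1, 0, 1, 0],
                     [1, 1, 0, 0],
                     [1, 1, 0, 0]])

# With drop_tokens=False the capacity must fit the busiest expert ...
new_capacity = mask.sum(dim=0).max()    # tensor(4): expert 0 receives all four tokens
# ... rounded up to a multiple of the tensor-parallel world size (illustrative tp=2).
tp = 2
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)    # tensor(4)

# The recomputed locations: 0-based slot of each selected token inside its expert's
# buffer, obtained as a running count of earlier tokens routed to the same expert.
locations = torch.cumsum(mask, dim=0) - 1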


@@ -253,6 +253,15 @@ class TestTopkGate(DistributedTest):
position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2]
check_equal(logits2, 2, position_sec_sparse, position_dispatch_res)
#s=4 e=4 topk=2 drop_tokens=False
logits3 = torch.tensor([[0.95, 0.85, 0.90, 0.80], [0.70, 0.65, 0.75, 0.60], [0.50, 0.55, 0.45, 0.40],
[0.35, 0.30, 0.25, 0.20]])
logits3 *= dist.get_rank() + 1
dispatch_res = topkgating(logits3, 2, 1, min_capacity=1, drop_tokens=False)[2]
sec_sparse = torch.tensor([[0, 0, 0], [0, 2, 0], [1, 0, 1], [1, 2, 1], [2, 0, 2], [2, 1, 0], [3, 0, 3],
[3, 1, 1]])
check_equal(logits3, 4, sec_sparse, dispatch_res)
class TestExpertWeightGradWithZero(DistributedTest):
world_size = 2
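
For reference, the expected sec_sparse rows in the new test case can be derived by hand: with top-2 routing and drop_tokens=False, each row is a (token, expert, slot) triple, where the slot is the token's position within that expert's buffer. The per-rank scaling of logits3 by dist.get_rank() + 1 does not change the top-k ordering, so it can be ignored here. A hedged standalone sketch that reproduces the triples from logits3 (the helper code below is illustrative, not the test's API):

import torch

logits3 = torch.tensor([[0.95, 0.85, 0.90, 0.80],
                        [0.70, 0.65, 0.75, 0.60],
                        [0.50, 0.55, 0.45, 0.40],
                        [0.35, 0.30, 0.25, 0.20]])

k = 2
# Top-k expert indices per token and the corresponding one-hot selection mask.
top_idx = torch.topk(logits3, k=k, dim=1).indices
mask = torch.zeros_like(logits3, dtype=torch.long).scatter_(1, top_idx, 1)

# 0-based slot of each selected token within its expert's buffer.
locations = torch.cumsum(mask, dim=0) - 1

# Collect (token, expert, slot) triples for every selected pair.
triples = [(s, e, int(locations[s, e]))
           for s in range(mask.shape[0])
           for e in range(mask.shape[1]) if mask[s, e]]
print(triples)
# [(0, 0, 0), (0, 2, 0), (1, 0, 1), (1, 2, 1), (2, 0, 2), (2, 1, 0), (3, 0, 3), (3, 1, 1)]
# These match the rows of sec_sparse checked against topkgating's dispatch output above.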