Mirror of https://github.com/deepspeedai/DeepSpeed.git, synced 2025-10-20 15:33:51 +08:00
fixed: Modified the topkgating function and updated the test_moe file for testing (#7163)
Since the previous PR ran into a DCO problem that could not be resolved, this is a resubmission of an otherwise identical PR without that issue.

---------

Signed-off-by: xiongjyu <xiongjyu@gmail.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Olatunji Ruwase <tjruwase@gmail.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
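For context, a minimal sketch of how the patched no-drop path is exercised. The call signature and the use of index [2] for the dispatch mask are taken from the new test in the diff below; the import path and everything else here are illustrative assumptions, not part of the commit.

# Sketch (not part of this commit): calling topkgating with
# drop_tokens=False, mirroring the test case added below.
import torch
from deepspeed.moe.sharded_moe import topkgating  # assumed import path

logits = torch.tensor([[0.95, 0.85, 0.90, 0.80],
                       [0.70, 0.65, 0.75, 0.60],
                       [0.50, 0.55, 0.45, 0.40],
                       [0.35, 0.30, 0.25, 0.20]])

# k=2, capacity_factor=1, no token dropping; index [2] selects the
# dispatch mask, as the test below does.
dispatch_mask = topkgating(logits, 2, 1, min_capacity=1, drop_tokens=False)[2]
print(dispatch_mask.shape)  # expected (tokens=4, experts=4, capacity)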
@@ -429,6 +429,7 @@ def topkgating(
             tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu)
             new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype)
         capacity = new_capacity
+        locations = torch.cumsum(mask, dim=0) - 1

     # normalize gates
     gates_masked = gates * mask
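The one added line recomputes locations from the routing mask after the no-drop capacity is finalized (the surrounding context rounds new_capacity up to a multiple of the tensor-parallel world size tp), so each token's slot index stays consistent with the full, undropped assignment. A toy illustration of what torch.cumsum(mask, dim=0) - 1 yields, with made-up values:

# Toy illustration only; the mask values are hypothetical.
import torch

# (tokens, experts) routing mask: 1 where a token is sent to an expert.
mask = torch.tensor([[1, 0, 1],
                     [1, 1, 0],
                     [0, 1, 1]])

# Running per-expert count minus one: the capacity slot each routed
# token occupies. Entries where mask == 0 are ignored downstream.
locations = torch.cumsum(mask, dim=0) - 1
print(locations)
# tensor([[ 0, -1,  0],
#         [ 1,  0,  0],
#         [ 1,  1,  1]])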
@@ -253,6 +253,15 @@ class TestTopkGate(DistributedTest):
         position_dispatch_res = topkgating(logits2, 3, 1, min_capacity=1, drop_policy='position')[2]
         check_equal(logits2, 2, position_sec_sparse, position_dispatch_res)

+        # s=4 e=4 topk=2 drop_tokens=False
+        logits3 = torch.tensor([[0.95, 0.85, 0.90, 0.80], [0.70, 0.65, 0.75, 0.60], [0.50, 0.55, 0.45, 0.40],
+                                [0.35, 0.30, 0.25, 0.20]])
+        logits3 *= dist.get_rank() + 1
+        dispatch_res = topkgating(logits3, 2, 1, min_capacity=1, drop_tokens=False)[2]
+        sec_sparse = torch.tensor([[0, 0, 0], [0, 2, 0], [1, 0, 1], [1, 2, 1], [2, 0, 2], [2, 1, 0], [3, 0, 3],
+                                   [3, 1, 1]])
+        check_equal(logits3, 4, sec_sparse, dispatch_res)
+

 class TestExpertWeightGradWithZero(DistributedTest):
     world_size = 2
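Each row of sec_sparse is a [token, expert, capacity-slot] triple naming a position that should be set in the dense dispatch mask returned by topkgating. A sketch of the comparison check_equal presumably performs; the helper's actual implementation lives in the test file, so this dense_from_triples version is an assumption:

# Assumed reconstruction of the check, not the actual check_equal helper.
import torch

def dense_from_triples(triples: torch.Tensor, shape) -> torch.Tensor:
    """Build a dense boolean dispatch mask from [token, expert, slot] rows."""
    dense = torch.zeros(shape, dtype=torch.bool)
    for s, e, c in triples.tolist():
        dense[s, e, c] = True
    return dense

sec_sparse = torch.tensor([[0, 0, 0], [0, 2, 0], [1, 0, 1], [1, 2, 1],
                           [2, 0, 2], [2, 1, 0], [3, 0, 3], [3, 1, 1]])
# s=4 tokens, e=4 experts, capacity 4 (presumably the '4' passed to check_equal).
expected = dense_from_triples(sec_sparse, (4, 4, 4))
# torch.equal(expected, dispatch_res.bool()) would then verify the routing.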