mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
35 lines
1.2 KiB
Python
35 lines
1.2 KiB
Python
'''
|
|
Copyright 2020 The Microsoft DeepSpeed Team
|
|
'''
|
|
|
|
import torch
|
|
import copy
|
|
|
|
|
|
class Experts(torch.nn.Module):
    """Hold ``num_local_experts`` independent copies of an expert module.

    Each copy is a ``copy.deepcopy`` of the module passed to the constructor,
    so all copies start with identical weights but own separate parameters.
    ``forward`` gives each copy its own slice of the input along dimension 1.
    """

    def __init__(self, expert, num_local_experts=1, expert_group_name=None):
        """Replicate ``expert`` and tag every replica parameter.

        Args:
            expert: ``torch.nn.Module`` to replicate. Only deep copies are
                stored on this module; the passed-in instance is not registered.
            num_local_experts: number of replicas to create.
            expert_group_name: value assigned to ``param.group_name`` on every
                parameter of every replica.
        """
        super().__init__()

        self.deepspeed_experts = torch.nn.ModuleList(
            [copy.deepcopy(expert) for _ in range(num_local_experts)])
        self.num_local_experts = num_local_experts

        # TODO: revisit allreduce for moe.gate...
        for local_expert in self.deepspeed_experts:
            # TODO: Create param groups to handle expert + data case (e.g. param.group = moe_group)
            for param in local_expert.parameters():
                # NOTE(review): these attributes are presumably consumed elsewhere
                # in DeepSpeed to exclude expert params from the ordinary
                # data-parallel all-reduce and to select their process group.
                # Confirm against the engine code.
                param.allreduce = False
                param.group_name = expert_group_name

    def forward(self, inputs):
        """Run each local expert on its own slice of ``inputs`` along dim 1.

        ``inputs`` is split with ``torch.chunk`` into ``num_local_experts``
        pieces along dimension 1. Piece ``i`` is fed to expert ``i``, and the
        results are concatenated back together along dimension 1.

        Args:
            inputs: tensor whose dimension 1 is divided among the local
                experts. It presumably indexes experts, e.g.
                ``(groups, experts, capacity, model_dim)``; TODO confirm
                against callers.

        Returns:
            The per-expert outputs concatenated along dimension 1. When an
            expert returns a tuple, only its first element is kept.

        Raises:
            ValueError: if dimension 1 of ``inputs`` does not split into
                exactly ``num_local_experts`` chunks.
        """
        chunks = inputs.chunk(self.num_local_experts, dim=1)
        if len(chunks) != self.num_local_experts:
            # torch.chunk may return fewer pieces than requested (e.g. a size-3
            # dim split 4 ways gives 3 pieces). Without this check, zip() below
            # would silently leave the trailing experts unused.
            raise ValueError(
                f"Experts.forward: cannot split inputs.size(1)={inputs.size(1)} "
                f"into {self.num_local_experts} chunks (got {len(chunks)})")

        expert_outputs = []
        for chunk, local_expert in zip(chunks, self.deepspeed_experts):
            out = local_expert(chunk)
            if isinstance(out, tuple):
                out = out[0]  # Ignore the bias term for now
            expert_outputs.append(out)

        return torch.cat(expert_outputs, dim=1)
|