[TiledMLP] moe support (#7622)
MoE routers seem to drop the `bs` dimension in `x`, so the `[bs, seqlen, hidden_size]` shape can no longer be assumed; `x` may arrive as `[seqlen, hidden_size]`. Support that use case.

Signed-off-by: Stas Bekman <stas@stason.org>
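Illustrative note (not part of the commit): the missing `bs` dimension typically arises because the router hands each expert only the tokens dispatched to it, flattened to 2D. A minimal sketch, with a hypothetical even/odd dispatch standing in for the real routing:

import torch

bs, seqlen, hidden_size = 2, 16, 8
x = torch.randn(bs, seqlen, hidden_size)                # dense path: [bs, seqlen, hidden_size]
tokens = x.reshape(-1, hidden_size)                     # routers typically flatten to [bs*seqlen, hidden_size]
expert_mask = torch.arange(tokens.shape[0]) % 2 == 0    # hypothetical: tokens routed to expert 0
expert_x = tokens[expert_mask]                          # [n_tokens, hidden_size] -- the bs dim is gone
print(expert_x.shape)                                   # torch.Size([16, 8])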
@@ -836,10 +836,11 @@ class TiledMLP(torch.autograd.Function):
         ctx.compute_params = [p for p in compute_params if p.requires_grad]
         ctx.save_for_backward(x)

-        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
         with torch.no_grad():
             output_shards = [fn(self, x_shard) for x_shard in x_shards]
-        output_unsharded = torch.cat(output_shards, dim=1)
+        output_unsharded = torch.cat(output_shards, dim=-2)

         return output_unsharded
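For orientation, a rough stand-alone imitation of the forward tiling above; `mlp` and `tiled_forward` are illustrative stand-ins, not the DeepSpeed API. The point is that chunking and concatenating along `dim=-2` behaves the same for 3D dense inputs and 2D expert inputs:

import torch
import torch.nn as nn

mlp = nn.Sequential(nn.Linear(512, 2048), nn.GELU(), nn.Linear(2048, 512))

def tiled_forward(fn, x, shards=4):
    # dim=-2 is the seqlen/token dim for both [bs, seqlen, hidden_size] and [seqlen, hidden_size]
    x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
    with torch.no_grad():
        output_shards = [fn(x_shard) for x_shard in x_shards]
    return torch.cat(output_shards, dim=-2)

assert tiled_forward(mlp, torch.randn(2, 64, 512)).shape == (2, 64, 512)   # dense
assert tiled_forward(mlp, torch.randn(64, 512)).shape == (64, 512)         # moe expert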
@@ -856,7 +857,9 @@ class TiledMLP(torch.autograd.Function):
         # detach() unsets `x.requires_grad`, so restore it
         x.requires_grad_(x_requires_grad)

-        bs, seqlen, hidden_size = x.shape
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        hidden_size = x.shape[-1]
+        x_shape_orig = x.shape

         # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
         x = x.view(-1, hidden_size)
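A small sketch (toy shapes only) of why reading just `hidden_size` from `x.shape[-1]` and remembering `x_shape_orig` is enough for both layouts:

import torch

for x in (torch.randn(2, 16, 8), torch.randn(16, 8)):    # with and without the bs dim
    hidden_size = x.shape[-1]
    x_shape_orig = x.shape
    x_flat = x.view(-1, hidden_size)                      # [bs*seqlen, hidden_size] or [seqlen, hidden_size]
    assert x_flat.view(x_shape_orig).shape == x.shape     # restores either layout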
@@ -890,7 +893,7 @@ class TiledMLP(torch.autograd.Function):
             torch.autograd.backward(output, incoming_grad_shard)

         # unflatten
-        x_grad = x_grad.view(bs, -1, hidden_size)
+        x_grad = x_grad.view(x_shape_orig)

         return (None, None, x_grad, None, None)
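And a rough, self-contained imitation of the backward path (the real backward drives `torch.autograd.backward(output, incoming_grad_shard)` per shard; the toy `fn` and per-shard bookkeeping here are illustrative only), showing that `view(x_shape_orig)` replaces the old `view(bs, -1, hidden_size)` and also covers the 2D expert case:

import torch

def tiled_grad(fn, x, shards=4):
    hidden_size = x.shape[-1]
    x_shape_orig = x.shape
    x_flat = x.view(-1, hidden_size)
    grad_shards = []
    for x_shard in torch.chunk(x_flat, chunks=shards, dim=0):
        x_shard = x_shard.detach().requires_grad_(True)
        output = fn(x_shard)
        torch.autograd.backward(output, torch.ones_like(output))
        grad_shards.append(x_shard.grad)
    # unflatten: view(x_shape_orig) works for [bs, seqlen, hidden_size] and [seqlen, hidden_size] alike
    return torch.cat(grad_shards, dim=0).view(x_shape_orig)

assert tiled_grad(lambda t: t * 2, torch.randn(16, 8)).shape == (16, 8)
assert tiled_grad(lambda t: t * 2, torch.randn(2, 16, 8)).shape == (2, 16, 8)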