[TiledMLP] moe support (#7622)

MoE routers seem to drop the `bs` dimension in `x`, so the `[bs, seqlen, hidden_size]` shape can no longer be assumed; expert inputs may arrive as `[seqlen, hidden_size]`. Support that use case.

Signed-off-by: Stas Bekman <stas@stason.org>
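
For reference, a minimal sketch (hypothetical routing code, not taken from this repo) of why expert inputs show up without a batch dimension: routers typically flatten tokens across the batch and gather the subset assigned to each expert, so the expert sees [n_tokens, hidden_size].

import torch

bs, seqlen, hidden_size = 2, 8, 4
x = torch.randn(bs, seqlen, hidden_size)

# hypothetical top-1 routing over 2 experts: flatten tokens, gather per expert
tokens = x.reshape(-1, hidden_size)          # [bs * seqlen, hidden_size]
expert_ids = torch.randint(0, 2, (tokens.shape[0],))
expert0_in = tokens[expert_ids == 0]         # [n_tokens, hidden_size] -- the bs dim is gone
print(expert0_in.shape)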
Author: Stas Bekman
Date: 2025-10-07 06:33:34 -07:00
Committed by: GitHub
Parent: 1ae1cdd8e4
Commit: 1b08325da3

@@ -836,10 +836,11 @@ class TiledMLP(torch.autograd.Function):
         ctx.compute_params = [p for p in compute_params if p.requires_grad]
         ctx.save_for_backward(x)
-        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
         with torch.no_grad():
             output_shards = [fn(self, x_shard) for x_shard in x_shards]
-        output_unsharded = torch.cat(output_shards, dim=1)
+        output_unsharded = torch.cat(output_shards, dim=-2)
         return output_unsharded
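
A quick standalone check (not part of the patch) that sharding on dim=-2 keeps the old dim=1 behaviour for [bs, seqlen, hidden_size] inputs and also accepts the 2-D expert case; fn is replaced by identity here just to exercise the chunk/cat round-trip:

import torch

def tile_roundtrip(x, shards=4):
    # shard along the token/seqlen axis, which is dim=-2 in both layouts
    x_shards = list(torch.chunk(x, chunks=shards, dim=-2))
    return torch.cat(x_shards, dim=-2)

x3d = torch.randn(2, 16, 8)   # [bs, seqlen, hidden_size]
x2d = torch.randn(16, 8)      # [seqlen, hidden_size] (moe experts)
assert torch.equal(tile_roundtrip(x3d), x3d)
assert torch.equal(tile_roundtrip(x2d), x2d)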
@@ -856,7 +857,9 @@ class TiledMLP(torch.autograd.Function):
         # detach() unsets `x.requires_grad`, so restore it
         x.requires_grad_(x_requires_grad)
-        bs, seqlen, hidden_size = x.shape
+        # x.shape could be [bs, seqlen, hidden_size] or [seqlen, hidden_size] (moe experts)
+        hidden_size = x.shape[-1]
+        x_shape_orig = x.shape
         # flatten bs+seqlen to avoid having stride issues when narrowing into seqlen w/ bs>1
         x = x.view(-1, hidden_size)
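
The same shape-agnostic handling in isolation (a sketch of the idea, not the actual backward): take hidden_size from the last axis, remember the original shape, and flatten everything else so narrowing along the token axis works for either rank:

import torch

def flatten_tokens(x):
    # works for [bs, seqlen, hidden_size] and [seqlen, hidden_size] alike
    hidden_size = x.shape[-1]
    x_shape_orig = x.shape
    return x.view(-1, hidden_size), x_shape_orig

flat, shape_orig = flatten_tokens(torch.randn(2, 16, 8))
print(flat.shape, shape_orig)   # torch.Size([32, 8]) torch.Size([2, 16, 8])
flat, shape_orig = flatten_tokens(torch.randn(16, 8))
print(flat.shape, shape_orig)   # torch.Size([16, 8]) torch.Size([16, 8])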
@@ -890,7 +893,7 @@ class TiledMLP(torch.autograd.Function):
             torch.autograd.backward(output, incoming_grad_shard)
         # unflatten
-        x_grad = x_grad.view(bs, -1, hidden_size)
+        x_grad = x_grad.view(x_shape_orig)
         return (None, None, x_grad, None, None)
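
And the matching un-flatten on the gradient side: instead of rebuilding the shape from bs (which no longer exists for 2-D inputs), view back to the saved x_shape_orig. A toy round-trip, assuming a contiguous flat gradient buffer:

import torch

for shape in [(2, 16, 8), (16, 8)]:                        # with and without the bs dim
    x = torch.randn(*shape)
    x_shape_orig = x.shape
    x_grad = torch.zeros(x.numel() // shape[-1], shape[-1])  # flat [tokens, hidden_size]
    x_grad += 1.0                                            # stand-in for per-shard backward
    x_grad = x_grad.view(x_shape_orig)                       # restore the original layout
    assert x_grad.shape == x.shape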