[alst tutorial] support bs>1 (#7550)

Edit tutorial's demo code to support bs>1 and prevent div by zero
This commit is contained in:
Stas Bekman
2025-09-09 12:51:42 -07:00
committed by GitHub
parent 450b965efb
commit 533e834b0a
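For context on the two fixes this commit makes: with batch size > 1 the label tensors are 2-D, and Python's builtin `sum()` over a 2-D tensor adds its rows and returns a vector instead of reducing to a scalar; and a shard that contains only prompt tokens has zero loss-contributing tokens, so an unguarded division produces `nan`. A minimal sketch of both failure modes (illustrative values only, not part of the diff):

```python
import torch

# Hypothetical labels for a batch of 2 sequences; -100 marks prompt tokens
# that are excluded from the loss (illustrative values only).
shift_labels = torch.tensor([[-100, -100, 5, 7],
                             [-100, 3, 9, -100]])

# builtin sum() iterates over dim 0 of a 2-D tensor, so it returns a 1-D
# tensor of per-position counts instead of the intended scalar count:
print(sum((shift_labels != -100).squeeze()))   # tensor([0, 1, 2, 1])
# Tensor.sum() reduces over all elements and gives the scalar count:
print((shift_labels != -100).squeeze().sum())  # tensor(4)

# If a shard holds only prompt tokens, the count is 0 and the unguarded
# division yields nan; max(count, 1) keeps the loss finite:
loss_sum, count = torch.tensor(0.0), torch.tensor(0)
print(loss_sum / count)          # tensor(nan)
print(loss_sum / max(count, 1))  # tensor(0.)
```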

@@ -116,11 +116,11 @@ for iter, batch in enumerate(dl):
# differentiable weighted per-shard-loss aggregation across ranks
losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
# special dealing with SFT that has prompt tokens that aren't used in loss computation
-good_tokens = sum((shift_labels != -100).view(-1))
+good_tokens = (shift_labels != -100).view(-1).sum()
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
total_good_tokens = sum(good_tokens_per_rank)
-loss = total_loss / total_good_tokens
+loss = total_loss / max(total_good_tokens, 1)
if dist.get_rank() == 0:
print(f"{iter}: {loss=}")
@@ -185,11 +185,11 @@ Since each rank processes a segment we need to average loss. To get the gradient
# differentiable weighted per-shard-loss aggregation across ranks
losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
# special dealing with SFT that has prompt tokens that aren't used in loss computation
-good_tokens = sum((shift_labels != -100).view(-1))
+good_tokens = (shift_labels != -100).view(-1).sum()
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
total_good_tokens = sum(good_tokens_per_rank)
-loss = total_loss / total_good_tokens
+loss = total_loss / max(total_good_tokens, 1)
```
In theory you could just average `losses_per_rank`, but the system supports variable sequence lengths, so the last rank is likely to hold a shorter segment, and use cases like SFT may also have a variable number of tokens contributing to the loss calculation. It's therefore best to compute a weighted loss.
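As a rough illustration of that point (made-up per-rank values, not from the tutorial), with unequal token counts the plain average of per-rank means drifts away from the true per-token mean, while the token-weighted form recovers it:

```python
import torch

# Hypothetical per-rank mean losses and good-token counts for sp_world_size=2.
losses_per_rank = [torch.tensor(2.0), torch.tensor(1.0)]
good_tokens_per_rank = [torch.tensor(900), torch.tensor(100)]
sp_world_size = 2

# Plain average treats both ranks equally: (2.0 + 1.0) / 2 = 1.5
naive = sum(losses_per_rank) / sp_world_size

# Weighted average recovers the mean over all 1000 tokens:
# (2.0 * 900 + 1.0 * 100) / 1000 = 1.9
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank]
                 for rank in range(sp_world_size))
weighted = total_loss / max(sum(good_tokens_per_rank), 1)
print(naive, weighted)  # tensor(1.5000) tensor(1.9000)
```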
@@ -258,16 +258,16 @@ If your model isn't supported by Liger-kernel you can use our implementation, wh
output_unshard_dimension=0, # loss is a scalar
output_reduction="sum",
)
-total_good_items = sum((shift_labels != -100).squeeze())
-loss = total_loss_sum / total_good_items
+total_good_items = (shift_labels != -100).squeeze().sum()
+loss = total_loss_sum / max(total_good_items, 1)
# differentiable weighted per-shard-loss aggregation across ranks
losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=self.sp_group)
-good_tokens = sum((shift_labels != -100).view(-1))
+good_tokens = (shift_labels != -100).view(-1).sum()
good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=self.sp_group)
total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(self.sp_world_size))
total_good_tokens = sum(good_tokens_per_rank)
-loss = total_loss / total_good_tokens
+loss = total_loss / max(total_good_tokens, 1)
return loss
```
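A side note on the aggregation pattern in the last hunk: each rank first turns its summed shard loss into a per-token mean (`total_loss_sum / max(total_good_items, 1)`), and the cross-rank step then re-weights that mean by the rank's token count before renormalizing, so the two steps compose back to the global per-token mean. A minimal non-distributed sketch of that arithmetic (two simulated ranks, made-up numbers; the real code does the second step via the differentiable `torch.distributed.nn.functional.all_gather`):

```python
import torch

# Simulated per-rank summed losses and good-token counts (illustrative only).
loss_sums   = [torch.tensor(180.0), torch.tensor(20.0)]
good_counts = [torch.tensor(90),    torch.tensor(10)]

# Step 1 (per rank): per-token mean, guarded against an all-prompt shard.
per_rank_means = [s / max(c, 1) for s, c in zip(loss_sums, good_counts)]

# Step 2 (across ranks): weight each mean by its token count, then renormalize.
total_loss = sum(m * c for m, c in zip(per_rank_means, good_counts))
total_good = sum(good_counts)
global_mean = total_loss / max(total_good, 1)

# Same as summing all losses and dividing by all tokens: (180 + 20) / 100 = 2.0
assert torch.isclose(global_mean, torch.tensor(2.0))
```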