[alstn tutorial] support bs>1 (#7550)
Edit the tutorial's demo code to support bs>1 and prevent division by zero.
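For context, here is a minimal, self-contained sketch of the pattern the diff switches to, using made-up `shift_labels` and loss values (not from the tutorial): the tensor `.sum()` reduces the label mask to a single 0-dim tensor for any batch size, and `max(..., 1)` guards the division when a batch contains no loss-contributing tokens.

```python
# Minimal sketch with invented values, illustrating the two fixes in the diff below.
import torch

# toy labels for bs=2; prompt tokens are masked with -100 and excluded from the loss
shift_labels = torch.tensor([[-100, -100, 5, 7],
                             [-100, 9, -100, -100]])

# .view(-1).sum() collapses the mask to a 0-dim count regardless of batch shape
good_tokens = (shift_labels != -100).view(-1).sum()   # tensor(3)

# max(..., 1) prevents division by zero when every label in the batch is -100
total_loss = torch.tensor(2.4)                        # invented summed loss
loss = total_loss / max(good_tokens, 1)
print(loss)                                           # tensor(0.8000)
```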
@@ -116,11 +116,11 @@ for iter, batch in enumerate(dl):
     # differentiable weighted per-shard-loss aggregation across ranks
     losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
     # special dealing with SFT that has prompt tokens that aren't used in loss computation
-    good_tokens = sum((shift_labels != -100).view(-1))
+    good_tokens = (shift_labels != -100).view(-1).sum()
     good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
     total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
     total_good_tokens = sum(good_tokens_per_rank)
-    loss = total_loss / total_good_tokens
+    loss = total_loss / max(total_good_tokens, 1)
 
     if dist.get_rank() == 0:
         print(f"{iter}: {loss=}")
@@ -185,11 +185,11 @@ Since each rank processes a segment we need to average loss. To get the gradient
 # differentiable weighted per-shard-loss aggregation across ranks
 losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
 # special dealing with SFT that has prompt tokens that aren't used in loss computation
-good_tokens = sum((shift_labels != -100).view(-1))
+good_tokens = (shift_labels != -100).view(-1).sum()
 good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
 total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
 total_good_tokens = sum(good_tokens_per_rank)
-loss = total_loss / total_good_tokens
+loss = total_loss / max(total_good_tokens, 1)
 ```
 
 In theory you could just average `losses_per_rank`, but the system supports variable sequence lengths, so the last rank is likely to hold a shorter segment, and use cases like SFT may have a variable number of tokens contributing to the loss calculation, so it's best to compute a weighted loss.
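To make the weighting argument above concrete, here is a small single-process check with invented per-token losses (the real code gathers them across `sp_group`): weighting each rank's mean loss by its good-token count recovers the global per-token mean, while a plain average of `losses_per_rank` does not.

```python
# Toy comparison with invented numbers: weighted vs. naive loss averaging across two "ranks".
import torch

# pretend rank 0 had 3 loss-contributing tokens and rank 1 (a shorter last shard) had 1
per_token_losses = [torch.tensor([0.5, 1.0, 1.5]), torch.tensor([4.0])]

losses_per_rank = [t.mean() for t in per_token_losses]        # what each rank computes locally
good_tokens_per_rank = [t.numel() for t in per_token_losses]  # 3 and 1

weighted = sum(l * n for l, n in zip(losses_per_rank, good_tokens_per_rank)) / sum(good_tokens_per_rank)
naive = sum(losses_per_rank) / len(losses_per_rank)

print(weighted)  # tensor(1.7500) -- equals the mean over all 4 tokens
print(naive)     # tensor(2.5000) -- over-weights the short shard
```

In the tutorial's distributed version, `losses_per_rank` and `good_tokens_per_rank` come from the differentiable `torch.distributed.nn.functional.all_gather` rather than local lists, so the same weighted sum also lets gradients flow back to every rank's shard.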
@@ -258,16 +258,16 @@ If your model isn't supported by Liger-kernel you can use our implementation, wh
         output_unshard_dimension=0, # loss is a scalar
         output_reduction="sum",
     )
-    total_good_items = sum((shift_labels != -100).squeeze())
-    loss = total_loss_sum / total_good_items
+    total_good_items = (shift_labels != -100).squeeze().sum()
+    loss = total_loss_sum / max(total_good_items, 1)
 
     # differentiable weighted per-shard-loss aggregation across ranks
     losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=self.sp_group)
-    good_tokens = sum((shift_labels != -100).view(-1))
+    good_tokens = (shift_labels != -100).view(-1).sum()
     good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=self.sp_group)
     total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(self.sp_world_size))
     total_good_tokens = sum(good_tokens_per_rank)
-    loss = total_loss / total_good_tokens
+    loss = total_loss / max(total_good_tokens, 1)
 
     return loss
 ```