diff --git a/docs/_tutorials/ulysses-alst-sequence-parallelism.md b/docs/_tutorials/ulysses-alst-sequence-parallelism.md
index a7bf0e442..44d068e58 100644
--- a/docs/_tutorials/ulysses-alst-sequence-parallelism.md
+++ b/docs/_tutorials/ulysses-alst-sequence-parallelism.md
@@ -116,11 +116,11 @@ for iter, batch in enumerate(dl):
     # differentiable weighted per-shard-loss aggregation across ranks
     losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
     # special dealing with SFT that has prompt tokens that aren't used in loss computation
-    good_tokens = sum((shift_labels != -100).view(-1))
+    good_tokens = (shift_labels != -100).view(-1).sum()
     good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
     total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
     total_good_tokens = sum(good_tokens_per_rank)
-    loss = total_loss / total_good_tokens
+    loss = total_loss / max(total_good_tokens, 1)
 
     if dist.get_rank() == 0:
         print(f"{iter}: {loss=}")
@@ -185,11 +185,11 @@ Since each rank processes a segment we need to average loss. To get the gradient
     # differentiable weighted per-shard-loss aggregation across ranks
     losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=sp_group)
     # special dealing with SFT that has prompt tokens that aren't used in loss computation
-    good_tokens = sum((shift_labels != -100).view(-1))
+    good_tokens = (shift_labels != -100).view(-1).sum()
     good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=sp_group)
     total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(sp_world_size))
     total_good_tokens = sum(good_tokens_per_rank)
-    loss = total_loss / total_good_tokens
+    loss = total_loss / max(total_good_tokens, 1)
 ```
 
 In theory you could just average `losses_per_rank`, but the system supports variable sequence length so the last rank is likely to have a shorter sequence length and also use cases like SFT may have a variable number of tokens that contribute to the loss calculation, so it's best to compute a weighted loss.
@@ -258,16 +258,16 @@ If your model isn't supported by Liger-kernel you can use our implementation, wh
                 output_unshard_dimension=0, # loss is a scalar
                 output_reduction="sum",
             )
-            total_good_items = sum((shift_labels != -100).squeeze())
-            loss = total_loss_sum / total_good_items
+            total_good_items = (shift_labels != -100).squeeze().sum()
+            loss = total_loss_sum / max(total_good_items, 1)
 
             # differentiable weighted per-shard-loss aggregation across ranks
             losses_per_rank = torch.distributed.nn.functional.all_gather(loss, group=self.sp_group)
-            good_tokens = sum((shift_labels != -100).view(-1))
+            good_tokens = (shift_labels != -100).view(-1).sum()
             good_tokens_per_rank = torch.distributed.nn.functional.all_gather(good_tokens, group=self.sp_group)
             total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank] for rank in range(self.sp_world_size))
             total_good_tokens = sum(good_tokens_per_rank)
-            loss = total_loss / total_good_tokens
+            loss = total_loss / max(total_good_tokens, 1)
 
         return loss
 ```
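
For reference, here is a minimal standalone sketch of the token-weighted aggregation the patched tutorial code performs. The helper name `aggregate_weighted_loss` and its signature are illustrative only (not part of the tutorial or the patch); it assumes `torch.distributed` is already initialized, `sp_group`/`sp_world_size` describe the sequence-parallel process group, and `shift_labels` marks loss-excluded tokens with `-100`.

```python
import torch.distributed.nn.functional as dist_nn_F


def aggregate_weighted_loss(loss, shift_labels, sp_group, sp_world_size):
    # gather the per-shard mean losses with autograd support so gradients
    # flow back to every rank's shard
    losses_per_rank = dist_nn_F.all_gather(loss, group=sp_group)

    # count this shard's tokens that actually contribute to the loss
    # (-100 marks excluded tokens, e.g. SFT prompt tokens)
    good_tokens = (shift_labels != -100).view(-1).sum()
    good_tokens_per_rank = dist_nn_F.all_gather(good_tokens, group=sp_group)

    # weight each shard's loss by its contributing-token count
    total_loss = sum(losses_per_rank[rank] * good_tokens_per_rank[rank]
                     for rank in range(sp_world_size))
    total_good_tokens = sum(good_tokens_per_rank)

    # max(..., 1) guards against division by zero when no shard has
    # any loss-contributing tokens
    return total_loss / max(total_good_tokens, 1)
```

As in the patch, weighting by token count handles both the shorter last shard under variable sequence length and SFT batches where many tokens are masked out of the loss.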