Fix issue #5242 grad_norm and loss is nan (#7171)

This PR addresses a regression introduced in commit
[61daaa1](61daaa1ea2)
that breaks gradient clipping when the gradient norm is non-finite (Inf or NaN).

The modified NaN/Inf handling logic in the `total_norm` calculation leads to unexpected behavior:

- **Original logic** ([v0.10.3](https://github.com/deepspeedai/DeepSpeed/blob/v0.10.3/deepspeed/runtime/zero/stage_1_and_2.py#L1233)): converted both NaN and Inf to `-1` before entering `unscale_and_clip_grads`.
- **Post-commit behavior**: when `total_norm` is Inf, `inf_or_nan.logical_not() * total_norm` evaluates to `0 * inf`, which is NaN rather than 0, so the recombined `total_norm` becomes NaN and gradient clipping fails.

Here is a minimal reproducible example comparing gradient clipping
behavior across implementations.
```python
import torch
import copy

def test(total_norm):
    test_old_deepspeed(total_norm)
    test_deepspeed(total_norm)
    test_torch(total_norm)
    test_deepspeed_fix(total_norm)

def test_old_deepspeed(total_norm_tensor):
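    # v0.10.3 behavior: an Inf/NaN total_norm is replaced by the -1 sentinel before unscale_and_clip_grads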
    total_norm = copy.deepcopy(total_norm_tensor)
    # https://github.com/deepspeedai/DeepSpeed/blob/v0.10.3/deepspeed/runtime/zero/stage_1_and_2.py#L1233
    if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
        total_norm = torch.tensor(float(-1))
        
    # https://github.com/deepspeedai/DeepSpeed/blob/v0.10.3/deepspeed/runtime/zero/stage_1_and_2.py#L1848
    clip_grad = float(1.0)
    loss_scale = float(1.0)
    combined_scale = loss_scale
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    if clip > 1:
        combined_scale = clip * loss_scale
    print(f"old_deepspeed: {1. / combined_scale}")

def test_deepspeed(total_norm_tensor):
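    # Post-commit (v0.16.4) behavior: masked recombination of total_norm, then clamp-based clipping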
    total_norm = copy.deepcopy(total_norm_tensor)
    # https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/runtime/zero/stage_1_and_2.py#L1710
    norm_is_inf = total_norm.isinf()
    norm_is_nan = total_norm.isnan()
    inf_or_nan = norm_is_nan.logical_or(norm_is_inf)

    err = torch.tensor(-1.0, dtype=torch.float)
    total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm

    # https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/runtime/zero/stage_1_and_2.py#L1970
    clip_grad = float(1.0)
    loss_scale = float(1.0)
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    clip = torch.clamp(clip, min=1.0)
    combined_scale = clip * loss_scale
    print(f"test_deepspeed: {1. / combined_scale}")
    
def test_torch(total_norm_tensor):
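    # Reference: the clip coefficient computed by torch.nn.utils.clip_grad_norm_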
    # https://github.com/pytorch/pytorch/blob/v2.6.0/torch/nn/utils/clip_grad.py#L155
    total_norm = copy.deepcopy(total_norm_tensor)
    max_norm = float(1.0)
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
    print(f"torch: {clip_coef_clamped}")

def test_deepspeed_fix(total_norm_tensor):
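    # Proposed fix: replace an Inf/NaN total_norm with the -1 sentinel before computing the scale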
    total_norm = copy.deepcopy(total_norm_tensor)
    if total_norm.isinf() or total_norm.isnan():
        total_norm = torch.tensor(-1.0, dtype=torch.float)

    # https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/runtime/zero/stage_1_and_2.py#L1970
    clip_grad = float(1.0)
    loss_scale = float(1.0)
    clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
    clip = torch.clamp(clip, min=1.0)
    combined_scale = clip * loss_scale
    print(f"test_deepspeed_fix: {1. / combined_scale}")
    
if __name__ == '__main__':
    print("*****NAN*****")
    test(torch.tensor(float('nan')))
    print("*****INF*****")
    test(torch.tensor(float('inf')))
    print("*****positive*****")
    test(torch.tensor(float(2.0)))

```
Result:

![20250325165135](https://github.com/user-attachments/assets/bd32209d-14f6-4c21-8b57-f8bd94786fe2)
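For reference (in case the screenshot does not render), the effective scale factors the script is expected to print, derived from the logic above, are roughly:

| `total_norm` | `test_old_deepspeed` (v0.10.3) | `test_deepspeed` (post-commit) | `test_torch` | `test_deepspeed_fix` |
|---|---|---|---|---|
| NaN | 1.0 | nan | nan | 1.0 |
| Inf | 1.0 | nan | 0.0 | 1.0 |
| 2.0 | ≈0.5 | ≈0.5 | ≈0.5 | ≈0.5 |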

---------

Signed-off-by: yueyang.hyy <yueyang.hyy@alibaba-inc.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
Commit: 1f706621f1 (parent: b8cc1eb078)
Author: Glaceon-Hyy
Date: 2025-03-29 08:37:29 +08:00
Committed by: GitHub

```diff
@@ -1727,10 +1727,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         norm_is_inf = total_norm.isinf()
         norm_is_nan = total_norm.isnan()
         inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
-        err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
-        total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
+        if norm_is_inf or norm_is_nan:
+            total_norm = torch.tensor(-1.0, device=self.device, dtype=torch.float)
         return total_norm
     # creates a flat fused tensor from the tensor list starting at the first_offset
@@ -1987,8 +1987,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
         if self.clip_grad > 0.:
             # norm is in fact norm*scale
             clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
-            clip = torch.clamp(clip, min=1.0)
-            combined_scale = clip * self.loss_scale
+            # handle total_norm invalid value -1
+            if clip > 1:
+                combined_scale = clip * self.loss_scale
         for grad in grad_groups_flat:
             if isinstance(grad, list):
```
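
Taken together, the two hunks collapse a non-finite norm to the `-1` sentinel, and that `-1` makes `clip` evaluate to roughly `-1`, so the restored `clip > 1` guard leaves `combined_scale` at `self.loss_scale`. Below is a rough standalone sketch of that combined path; `unscale_factor` and its parameters are illustrative stand-ins, not DeepSpeed's actual optimizer API:

```python
import torch

def unscale_factor(total_norm: torch.Tensor, loss_scale: float = 1.0, clip_grad: float = 1.0) -> float:
    # Hunk 1: collapse an Inf/NaN norm to the -1 sentinel
    if total_norm.isinf() or total_norm.isnan():
        total_norm = torch.tensor(-1.0, dtype=torch.float)

    # Hunk 2: the sentinel gives clip ~ -1, so the "clip > 1" guard skips clipping
    combined_scale = loss_scale
    if clip_grad > 0.:
        clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
        if clip > 1:
            combined_scale = clip * loss_scale
    return float(1. / combined_scale)

print(unscale_factor(torch.tensor(float('inf'))))  # 1.0  -> an invalid norm no longer poisons the step
print(unscale_factor(torch.tensor(2.0)))           # ~0.5 -> ordinary clipping still applies
```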