mirror of
https://github.com/deepspeedai/DeepSpeed.git
synced 2025-10-20 15:33:51 +08:00
This PR addresses a regression introduced in commit
[61daaa1](61daaa1ea2
)
that affects gradient clipping when handling infinite values.
The modified NaN/Inf handling logic in total_norm calculation leads to
unexpected behavior:
Original logic
([v0.10.3](https://github.com/deepspeedai/DeepSpeed/blob/v0.10.3/deepspeed/runtime/zero/stage_1_and_2.py#L1233)):
Converted both NaN and Inf to -1 before entering unscale_and_clip_grads
Post-commit behavior: When total_norm is Inf, inf_or_nan.logical_not() *
total_norm produces NaN instead of 0, causing gradient clipping to fail
Here is a minimal reproducible example comparing gradient clipping
behavior across implementations.
```python
import torch
import numpy as np
import copy
def test(total_norm):
test_old_deepspeed(total_norm)
test_deepspeed(total_norm)
test_torch(total_norm)
test_deepspeed_fix(total_norm)
def test_old_deepspeed(total_norm_tensor):
total_norm = copy.deepcopy(total_norm_tensor)
# https://github.com/deepspeedai/DeepSpeed/blob/v0.10.3/deepspeed/runtime/zero/stage_1_and_2.py#L1233
if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm:
total_norm = torch.tensor(float(-1))
# https://github.com/deepspeedai/DeepSpeed/blob/v0.10.3/deepspeed/runtime/zero/stage_1_and_2.py#L1848
clip_grad = float(1.0)
loss_scale = float(1.0)
combined_scale = loss_scale
clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
if clip > 1:
combined_scale = clip * loss_scale
print(f"old_deepspeed: {1. / combined_scale}")
def test_deepspeed(total_norm_tensor):
total_norm = copy.deepcopy(total_norm_tensor)
# https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/runtime/zero/stage_1_and_2.py#L1710
norm_is_inf = total_norm.isinf()
norm_is_nan = total_norm.isnan()
inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
err = torch.tensor(-1.0, dtype=torch.float)
total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
# https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/runtime/zero/stage_1_and_2.py#L1970
clip_grad = float(1.0)
loss_scale = float(1.0)
clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
clip = torch.clamp(clip, min=1.0)
combined_scale = clip * loss_scale
print(f"test_deepspeed: {1. / combined_scale}")
def test_torch(total_norm_tensor):
# https://github.com/pytorch/pytorch/blob/v2.6.0/torch/nn/utils/clip_grad.py#L155
total_norm = copy.deepcopy(total_norm_tensor)
max_norm = float(1.0)
clip_coef = max_norm / (total_norm + 1e-6)
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
print(f"torch: {clip_coef_clamped}")
def test_deepspeed_fix(total_norm_tensor):
total_norm = copy.deepcopy(total_norm_tensor)
if total_norm.isinf() or total_norm.isnan():
total_norm = torch.tensor(-1.0, dtype=torch.float)
# https://github.com/deepspeedai/DeepSpeed/blob/v0.16.4/deepspeed/runtime/zero/stage_1_and_2.py#L1970
clip_grad = float(1.0)
loss_scale = float(1.0)
clip = ((total_norm / loss_scale) + 1e-6) / clip_grad
clip = torch.clamp(clip, min=1.0)
combined_scale = clip * loss_scale
print(f"test_deepspeed_fix: {1. / combined_scale}")
if __name__ == '__main__':
print("*****NAN*****")
test(torch.tensor(float('nan')))
print("*****INF*****")
test(torch.tensor(float('inf')))
print("*****positive*****")
test(torch.tensor(float(2.0)))
```
Result:

---------
Signed-off-by: yueyang.hyy <yueyang.hyy@alibaba-inc.com>
Co-authored-by: Hongwei Chen <33092912+hwchen2017@users.noreply.github.com>
This commit is contained in:
@ -1727,10 +1727,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
|
||||
|
||||
norm_is_inf = total_norm.isinf()
|
||||
norm_is_nan = total_norm.isnan()
|
||||
inf_or_nan = norm_is_nan.logical_or(norm_is_inf)
|
||||
|
||||
err = torch.tensor(-1.0, device=self.device, dtype=torch.float)
|
||||
total_norm = inf_or_nan * err + inf_or_nan.logical_not() * total_norm
|
||||
if norm_is_inf or norm_is_nan:
|
||||
total_norm = torch.tensor(-1.0, device=self.device, dtype=torch.float)
|
||||
|
||||
return total_norm
|
||||
|
||||
# creates a flat fused tensor from the tensor list starting at the first_offset
|
||||
@ -1987,8 +1987,10 @@ class DeepSpeedZeroOptimizer(ZeROOptimizer):
|
||||
if self.clip_grad > 0.:
|
||||
# norm is in fact norm*scale
|
||||
clip = ((total_norm / self.loss_scale) + 1e-6) / self.clip_grad
|
||||
clip = torch.clamp(clip, min=1.0)
|
||||
combined_scale = clip * self.loss_scale
|
||||
|
||||
# handle total_norm invalid value -1
|
||||
if clip > 1:
|
||||
combined_scale = clip * self.loss_scale
|
||||
|
||||
for grad in grad_groups_flat:
|
||||
if isinstance(grad, list):
|
||||
|
Reference in New Issue
Block a user