Add hint message when parameters is empty in clip_grad_norm_ (#151529)

Fixes #148259

## Changes

- Emit a `UserWarning` when the `parameters` generator is already exhausted, instead of silently skipping gradient clipping (see the sketch below)
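
For context, here is a minimal sketch of the kind of check this adds. This is illustrative only, not the actual implementation in `torch/nn/utils/clip_grad.py`, which additionally handles `error_if_nonfinite`, the `foreach` fast path, and per-device/dtype grouping:

```python
import warnings
from typing import Iterable, Union

import torch
from torch import Tensor


def clip_grad_norm_sketch(
    parameters: Union[Tensor, Iterable[Tensor]],
    max_norm: float,
    norm_type: float = 2.0,
) -> Tensor:
    """Illustrative sketch only -- not the real torch.nn.utils.clip_grad_norm_."""
    if isinstance(parameters, Tensor):
        parameters = [parameters]
    # Materializing an already-exhausted generator yields an empty list.
    grads = [p.grad for p in parameters if p.grad is not None]
    if len(grads) == 0:
        # The hint added by this PR: warn instead of silently returning 0.
        warnings.warn(
            "`parameters` is an empty generator, no gradient clipping will occur."
        )
        return torch.tensor(0.0)
    # Single-device sketch of the clipping math itself.
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g.detach(), norm_type) for g in grads]),
        norm_type,
    )
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for g in grads:
            g.detach().mul_(clip_coef)
    return total_norm
```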

## Test Result
### Warning output
```python
import torch
import torch.nn as nn
import torch.optim as optim

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(16, 10)
targets = torch.randn(16, 1)

outputs = model(inputs)
loss = criterion(outputs, targets)
optimizer.zero_grad()
loss.backward()

# model.parameters() returns a one-shot generator
params_to_clip = model.parameters()

# This loop consumes the generator...
for p in params_to_clip:
    print(p.shape)

max_norm = 1.0
norm_type = 2.0
# ...so clip_grad_norm_ receives an exhausted iterator and warns
total_norm = nn.utils.clip_grad_norm_(params_to_clip, max_norm, norm_type)
print(f"total_norm: {total_norm}")
```

```bash
/home/zong/code/pytorch/torch/nn/utils/clip_grad.py:222: UserWarning: `parameters` is an empty generator, no gradient clipping will occur.
  warnings.warn(
total_norm: 0.0
```
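
The `total_norm: 0.0` follows directly from the exhausted generator: the `for` loop above already consumed `params_to_clip`, so `clip_grad_norm_` sees no tensors. Materializing the parameters into a list first avoids the problem entirely; a usage sketch reusing the names from the script above:

```python
# Materialize the generator once; a list can be iterated repeatedly.
params_to_clip = list(model.parameters())

for p in params_to_clip:
    print(p.shape)

# The list is still populated here, so clipping actually occurs
# and no warning is emitted.
total_norm = nn.utils.clip_grad_norm_(params_to_clip, max_norm, norm_type)
print(f"total_norm: {total_norm}")
```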

### Unit tests

```bash
pytest test/test_nn.py -k test_clip_grad_norm
```

![image](https://github.com/user-attachments/assets/0aa0f06c-e0a5-43cf-9a97-d7c2747c9180)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151529
Approved by: https://github.com/jbschlosser
Author: zeshengzong
Date: 2025-05-22 11:23:34 +00:00
Committed by: PyTorch MergeBot
Parent: 40e6ca24ef
Commit: f12d8d60b1

2 changed files with 18 additions and 0 deletions


```diff
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -13117,6 +13117,16 @@ if __name__ == '__main__':
             clip_grad_norm_([p2], max_norm, norm_type=norm_type, foreach=foreach)
             self.assertEqual(p1.grad, p2.grad)
+            # Should warn when the parameters generator is exhausted
+            params = l.parameters()
+            for p in params:
+                pass
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter("always")
+                clip_grad_norm_(params, max_norm, norm_type=norm_type, foreach=foreach)
+            self.assertEqual(len(w), 1)
+            self.assertEqual(str(w[0].message), "`parameters` is an empty generator, no gradient clipping will occur.")
+
 
     # reference issue: https://github.com/pytorch/pytorch/issues/111484
     @onlyCUDA
     @largeTensorTest("42GB", "cuda")
```