Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add hint message when parameters is empty in clip_grad_norm_ (#151529)
Fixes #148259

## Changes

- Emit a warning when the `parameters` generator passed to `clip_grad_norm_` is already exhausted, since no gradient clipping will occur in that case.

## Test Result

### print warning

```python
import torch
import torch.nn as nn
import torch.optim as optim


class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)


model = SimpleModel()
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

inputs = torch.randn(16, 10)
targets = torch.randn(16, 1)

outputs = model(inputs)
loss = criterion(outputs, targets)

optimizer.zero_grad()
loss.backward()

params_to_clip = model.parameters()
for p in params_to_clip:
    print(p.shape)

max_norm = 1.0
norm_type = 2.0
total_norm = nn.utils.clip_grad_norm_(params_to_clip, max_norm, norm_type)
print(f"total_norm: {total_norm}")
```

```bash
/home/zong/code/pytorch/torch/nn/utils/clip_grad.py:222: UserWarning: `parameters` is an empty generator, no gradient clipping will occur.
  warnings.warn(
total_norm: 0.0
```

### UT

```bash
pytest test/test_nn.py -k test_clip_grad_norm
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151529
Approved by: https://github.com/jbschlosser
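For readers who hit this warning, a minimal sketch of the caller-side workaround under the same setup as the repro above (variable names here are illustrative, not part of the PR): materialize the generator into a list once, so inspecting the parameters does not leave `clip_grad_norm_` with nothing to clip.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
loss = model(torch.randn(16, 10)).sum()
loss.backward()

# list(...) materializes the parameters once; iterating the list for
# debugging does not consume it, unlike a generator.
params_to_clip = list(model.parameters())
for p in params_to_clip:
    print(p.shape)

total_norm = nn.utils.clip_grad_norm_(params_to_clip, max_norm=1.0)
print(f"total_norm: {total_norm}")  # total gradient norm; non-zero because the list was not consumed
```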
Committed by: PyTorch MergeBot
Parent: 40e6ca24ef
Commit: f12d8d60b1
test/test_nn.py

@@ -13117,6 +13117,16 @@ if __name__ == '__main__':
             clip_grad_norm_([p2], max_norm, norm_type=norm_type, foreach=foreach)
             self.assertEqual(p1.grad, p2.grad)
 
+            # Should warn when the parameters generator is exhausted
+            params = l.parameters()
+            for p in params:
+                pass
+            with warnings.catch_warnings(record=True) as w:
+                warnings.simplefilter("always")
+                clip_grad_norm_(params, max_norm, norm_type=norm_type, foreach=foreach)
+            self.assertEqual(len(w), 1)
+            self.assertEqual(str(w[0].message), "`parameters` is an empty generator, no gradient clipping will occur.")
+
         # reference issue: https://github.com/pytorch/pytorch/issues/111484
         @onlyCUDA
         @largeTensorTest("42GB", "cuda")
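As a standalone illustration of the warning-capture pattern the new test relies on, here is a minimal sketch, assuming a PyTorch build that already contains this change (the small linear module is only an example):

```python
import warnings

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_

l = nn.Linear(4, 4)
l(torch.randn(2, 4)).sum().backward()

params = l.parameters()
for p in params:
    pass  # exhausts the generator before clip_grad_norm_ sees it

with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always")  # record every warning, even ones already triggered
    clip_grad_norm_(params, max_norm=1.0)

assert len(w) == 1
assert "empty generator" in str(w[0].message)
```

`simplefilter("always")` matters because Python deduplicates warnings by default, so a warning already triggered earlier in the process would otherwise not be recorded.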
torch/nn/utils/clip_grad.py

@@ -1,7 +1,9 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import functools
+import types
 import typing
+import warnings
 from typing import cast, Optional, Union
 from typing_extensions import deprecated
 

@@ -213,8 +215,14 @@ def clip_grad_norm_(
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
     else:
+        is_generator = isinstance(parameters, types.GeneratorType)
         # prevent generators from being exhausted
         parameters = list(parameters)
+        if is_generator and len(parameters) == 0:
+            warnings.warn(
+                "`parameters` is an empty generator, no gradient clipping will occur.",
+                stacklevel=3,
+            )
     grads = [p.grad for p in parameters if p.grad is not None]
     total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
     _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
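Read in isolation, the new logic amounts to the following sketch. The `_materialize_parameters` helper name is hypothetical and introduced here only to show the pattern; in the actual change the code sits inline in `clip_grad_norm_`.

```python
import types
import warnings
from typing import Iterable

import torch


def _materialize_parameters(
    parameters: Iterable[torch.Tensor],
) -> list[torch.Tensor]:
    # Hypothetical helper mirroring the logic above: record whether the input
    # was a generator *before* consuming it, then materialize it so the
    # parameters can be traversed more than once.
    is_generator = isinstance(parameters, types.GeneratorType)
    params = list(parameters)
    if is_generator and not params:
        # Nothing to clip; likely the caller consumed the generator earlier.
        # stacklevel=2 attributes the warning to this helper's caller; the
        # actual change passes stacklevel=3 to account for where the code
        # sits in clip_grad_norm_'s call chain.
        warnings.warn(
            "`parameters` is an empty generator, no gradient clipping will occur.",
            stacklevel=2,
        )
    return params
```

Checking `types.GeneratorType` before calling `list()` is what distinguishes an exhausted or empty generator (likely a caller bug, worth warning about) from an intentionally empty list or tuple, which stays silent.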