Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Add APIs to separate norm calculation and gradient scaling in nn.utils.clip_grad_norm_ (#139662)

Fixes https://github.com/pytorch/pytorch/issues/139467

Refactors `nn.utils.clip_grad_norm_` into two new utilities, `nn.utils.get_total_norm` and `nn.utils.clip_grads_with_norm_`; `clip_grad_norm_` now calls into these two new ops. The norm-computation step is exposed as the generalized `get_total_norm` (rather than a gradient-specific `get_grad_norm`), following the discussion on the issue with @awgu.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139662
Approved by: https://github.com/H-Huang
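For context, a minimal sketch (not part of the commit) of how the decomposed APIs can replace a single `clip_grad_norm_` call. The model and tensors below are placeholders, and the snippet assumes a PyTorch build that already includes this change:

    import torch
    import torch.nn as nn
    from torch.nn.utils import clip_grads_with_norm_, get_total_norm

    model = nn.Linear(10, 10)                   # placeholder model
    model(torch.randn(4, 10)).sum().backward()  # populate .grad on the parameters
    max_norm = 1.0

    # Step 1: compute the total norm only (no gradients are modified).
    grads = [p.grad for p in model.parameters() if p.grad is not None]
    total_norm = get_total_norm(grads, norm_type=2.0)

    # Step 2: scale the gradients in-place using the pre-computed norm.
    clip_grads_with_norm_(model.parameters(), max_norm, total_norm)

    # Equivalent fused call: total_norm = clip_grad_norm_(model.parameters(), max_norm)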
Committed by: PyTorch MergeBot
Parent: 09ba38c4b7
Commit: 2ee91db03d
@@ -373,6 +373,8 @@ Utility functions to clip parameter gradients.

 clip_grad_norm_
 clip_grad_norm
 clip_grad_value_
+get_total_norm
+clip_grads_with_norm_

 Utility functions to flatten and unflatten Module parameters to and from a single vector.
@@ -22,7 +22,7 @@ import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.utils.rnn as rnn_utils
-from torch.nn.utils import clip_grad_norm_, clip_grad_value_
+from torch.nn.utils import clip_grad_norm_, clip_grad_value_, clip_grads_with_norm_, get_total_norm
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 from torch.nn.utils.fusion import fuse_linear_bn_weights
@@ -12820,6 +12820,20 @@ if __name__ == '__main__':
             self.assertLessEqual(norm_after, norm_before)
             compare_scaling(grads)

+            # decomposed APIs should behave as expected
+            grads = torch.arange(1., 101, device=device).view(10, 10), torch.ones(10, device=device).div(1000)
+            for p, g in zip(l.parameters(), grads):
+                p._grad = g.clone().view_as(p)
+            norm_before = compute_norm(norm_type)
+            grads = [p.grad for p in l.parameters()]
+            total_norm = get_total_norm(grads, norm_type=norm_type, foreach=foreach)
+            clip_grads_with_norm_(l.parameters(), max_norm, total_norm, foreach=foreach)
+            norm_after = compute_norm(norm_type)
+            self.assertEqual(total_norm, norm_before)
+            self.assertEqual(norm_after, max_norm)
+            self.assertLessEqual(norm_after, norm_before)
+            compare_scaling(grads)
+
             # Small gradients should be left unchanged
             grads = torch.rand(10, 10, device=device).div(10000), torch.ones(10, device=device).div(500)
             for p, g in zip(l.parameters(), grads):
@@ -1,5 +1,11 @@
 from . import parametrizations, rnn, stateless
-from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_
+from .clip_grad import (
+    _clip_grads_with_norm_ as clip_grads_with_norm_,
+    _get_total_norm as get_total_norm,
+    clip_grad_norm,
+    clip_grad_norm_,
+    clip_grad_value_,
+)
 from .convert_parameters import parameters_to_vector, vector_to_parameters
 from .fusion import (
     fuse_conv_bn_eval,
@@ -19,6 +25,7 @@ from .weight_norm import remove_weight_norm, weight_norm
 __all__ = [
     "clip_grad_norm",
     "clip_grad_norm_",
+    "clip_grads_with_norm_",
     "clip_grad_value_",
     "convert_conv2d_weight_memory_format",
     "convert_conv3d_weight_memory_format",
@@ -26,6 +33,7 @@ __all__ = [
     "fuse_conv_bn_weights",
     "fuse_linear_bn_eval",
     "fuse_linear_bn_weights",
+    "get_total_norm",
     "parameters_to_vector",
     "parametrizations",
     "remove_spectral_norm",
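As a side note on the re-export above, the underscore-prefixed functions in `clip_grad.py` and the public names in `torch.nn.utils` refer to the same objects. A small sketch (not part of the commit, assuming a build with this change):

    from torch.nn.utils import clip_grads_with_norm_, get_total_norm
    from torch.nn.utils.clip_grad import _clip_grads_with_norm_, _get_total_norm

    # The public names are aliases created in torch/nn/utils/__init__.py.
    assert get_total_norm is _get_total_norm
    assert clip_grads_with_norm_ is _clip_grads_with_norm_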
@@ -13,7 +13,11 @@ from torch.utils._foreach_utils import (
 )


-__all__ = ["clip_grad_norm_", "clip_grad_norm", "clip_grad_value_"]
+__all__ = [
+    "clip_grad_norm_",
+    "clip_grad_norm",
+    "clip_grad_value_",
+]


 _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
@@ -33,6 +37,141 @@ def _no_grad(func):
     return _no_grad_wrapper


+@_no_grad
+def _get_total_norm(
+    tensors: _tensor_or_tensors,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = False,
+    foreach: Optional[bool] = None,
+) -> torch.Tensor:
+    r"""Compute the norm of an iterable of tensors.
+
+    The norm is computed over the norms of the individual tensors, as if the norms of
+    the individual tensors were concatenated into a single vector.
+
+    Args:
+        tensors (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will be normalized
+        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+        error_if_nonfinite (bool): if True, an error is thrown if the total
+            norm of :attr:`tensors` is ``nan``, ``inf``, or ``-inf``.
+            Default: ``False``
+        foreach (bool): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+
+    Returns:
+        Total norm of the tensors (viewed as a single vector).
+    """
+    if isinstance(tensors, torch.Tensor):
+        tensors = [tensors]
+    else:
+        tensors = list(tensors)
+    norm_type = float(norm_type)
+    if len(tensors) == 0:
+        return torch.tensor(0.0)
+    first_device = tensors[0].device
+    grouped_tensors: Dict[
+        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+    ] = _group_tensors_by_device_and_dtype(
+        [tensors]  # type: ignore[list-item]
+    )  # type: ignore[assignment]
+
+    norms: List[Tensor] = []
+    for (device, _), ([device_tensors], _) in grouped_tensors.items():  # type: ignore[assignment]
+        if (foreach is None and _has_foreach_support(device_tensors, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            norms.extend(torch._foreach_norm(device_tensors, norm_type))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            norms.extend(
+                [torch.linalg.vector_norm(g, norm_type) for g in device_tensors]
+            )
+
+    total_norm = torch.linalg.vector_norm(
+        torch.stack([norm.to(first_device) for norm in norms]), norm_type
+    )
+
+    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+        raise RuntimeError(
+            f"The total norm of order {norm_type} for gradients from "
+            "`parameters` is non-finite, so it cannot be clipped. To disable "
+            "this error and scale the gradients by the non-finite norm anyway, "
+            "set `error_if_nonfinite=False`"
+        )
+    return total_norm
+
+
+@_no_grad
+def _clip_grads_with_norm_(
+    parameters: _tensor_or_tensors,
+    max_norm: float,
+    total_norm: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    r"""Scale the gradients of an iterable of parameters given a pre-calculated total norm and desired max norm.
+
+    The gradients will be scaled by the following calculation
+
+    .. math::
+        grad = grad * \frac{max\_norm}{total\_norm + 1e-6}
+
+    Gradients are modified in-place.
+
+    This function is equivalent to :func:`torch.nn.utils.clip_grad_norm_` with a pre-calculated
+    total norm.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float): max norm of the gradients
+        total_norm (Tensor): total norm of the gradients to use for clipping
+        foreach (bool): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+
+    Returns:
+        None
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    if len(grads) == 0:
+        return
+    grouped_grads: Dict[
+        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+    ] = _group_tensors_by_device_and_dtype(
+        [grads]
+    )  # type: ignore[assignment]
+
+    clip_coef = max_norm / (total_norm + 1e-6)
+    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+    # when the gradients do not reside in CPU memory.
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            clip_coef_clamped_device = clip_coef_clamped.to(device)
+            for g in device_grads:
+                g.mul_(clip_coef_clamped_device)
+
+
 @_no_grad
 def clip_grad_norm_(
     parameters: _tensor_or_tensors,
@@ -47,6 +186,9 @@ def clip_grad_norm_(
     as if the norms of the individual gradients were concatenated into a single vector.
     Gradients are modified in-place.

+    This function is equivalent to :func:`torch.nn.utils.get_total_norm` followed by
+    :func:`torch.nn.utils.clip_grads_with_norm_` with the ``total_norm`` returned by ``get_total_norm``.
+
     Args:
         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
             single Tensor that will have gradients normalized
@@ -66,61 +208,12 @@ def clip_grad_norm_(
     """
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
     else:
         # prevent generators from being exhausted
         parameters = list(parameters)
     grads = [p.grad for p in parameters if p.grad is not None]
     max_norm = float(max_norm)
     norm_type = float(norm_type)
-    if len(grads) == 0:
-        return torch.tensor(0.0)
-    first_device = grads[0].device
-    grouped_grads: Dict[
-        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
-    ] = _group_tensors_by_device_and_dtype(
-        [grads]
-    )  # type: ignore[assignment]
-
-    norms: List[Tensor] = []
-    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
-        if (foreach is None and _has_foreach_support(device_grads, device)) or (
-            foreach and _device_has_foreach_support(device)
-        ):
-            norms.extend(torch._foreach_norm(device_grads, norm_type))
-        elif foreach:
-            raise RuntimeError(
-                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
-            )
-        else:
-            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
-
-    total_norm = torch.linalg.vector_norm(
-        torch.stack([norm.to(first_device) for norm in norms]), norm_type
-    )
-
-    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
-        raise RuntimeError(
-            f"The total norm of order {norm_type} for gradients from "
-            "`parameters` is non-finite, so it cannot be clipped. To disable "
-            "this error and scale the gradients by the non-finite norm anyway, "
-            "set `error_if_nonfinite=False`"
-        )
-    clip_coef = max_norm / (total_norm + 1e-6)
-    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
-    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
-    # when the gradients do not reside in CPU memory.
-    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
-    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
-        if (foreach is None and _has_foreach_support(device_grads, device)) or (
-            foreach and _device_has_foreach_support(device)
-        ):
-            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
-        elif foreach:
-            raise RuntimeError(
-                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
-            )
-        else:
-            clip_coef_clamped_device = clip_coef_clamped.to(device)
-            for g in device_grads:
-                g.mul_(clip_coef_clamped_device)
+    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
+    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
     return total_norm
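To illustrate the scaling performed by `_clip_grads_with_norm_` above: every gradient is multiplied by `min(1, max_norm / (total_norm + 1e-6))`, so when the total norm is already at or below `max_norm` the gradients are left essentially unchanged. A small numeric sketch (not part of the commit, assuming a build with this change):

    import torch
    from torch.nn.utils import clip_grads_with_norm_, get_total_norm

    p = torch.nn.Parameter(torch.ones(4))
    p.grad = torch.tensor([3.0, 4.0, 0.0, 0.0])    # L2 norm of the gradient is 5.0

    total_norm = get_total_norm([p.grad])          # tensor(5.)
    clip_grads_with_norm_([p], max_norm=1.0, total_norm=total_norm)

    # Each gradient was multiplied by min(1, 1.0 / (5.0 + 1e-6)) ~= 0.2
    print(p.grad)                                  # approximately [0.6, 0.8, 0.0, 0.0]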