Add APIs to separate norm calculation and gradient scaling in nn.utils.clip_grad_norm_ (#139662)

Fixes https://github.com/pytorch/pytorch/issues/139467

Refactor `nn.utils.clip_grad_norm_` into `nn.utils.get_total_norm` and then `nn.utils.clip_grads_with_norm_` . `clip_grad_norm_` now calls into these two new ops,

`get_total_norm` is generalized (rather than `get_grad_norm` due to the discussion on the issue from @awgu)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139662
Approved by: https://github.com/H-Huang
This commit is contained in:
Mikayla Gawarecki
2024-11-06 15:40:53 -08:00
committed by PyTorch MergeBot
parent 09ba38c4b7
commit 2ee91db03d
4 changed files with 174 additions and 57 deletions

View File

@ -373,6 +373,8 @@ Utility functions to clip parameter gradients.
clip_grad_norm_
clip_grad_norm
clip_grad_value_
get_total_norm
clip_grads_with_norm_
Utility functions to flatten and unflatten Module parameters to and from a single vector.

View File

@ -22,7 +22,7 @@ import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.rnn as rnn_utils
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from torch.nn.utils import clip_grad_norm_, clip_grad_value_, clip_grads_with_norm_, get_total_norm
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.nn.utils.fusion import fuse_conv_bn_weights
from torch.nn.utils.fusion import fuse_linear_bn_weights
@ -12820,6 +12820,20 @@ if __name__ == '__main__':
self.assertLessEqual(norm_after, norm_before)
compare_scaling(grads)
# decomposed APIs should behave as expected
grads = torch.arange(1., 101, device=device).view(10, 10), torch.ones(10, device=device).div(1000)
for p, g in zip(l.parameters(), grads):
p._grad = g.clone().view_as(p)
norm_before = compute_norm(norm_type)
grads = [p.grad for p in l.parameters()]
total_norm = get_total_norm(grads, norm_type=norm_type, foreach=foreach)
clip_grads_with_norm_(l.parameters(), max_norm, total_norm, foreach=foreach)
norm_after = compute_norm(norm_type)
self.assertEqual(total_norm, norm_before)
self.assertEqual(norm_after, max_norm)
self.assertLessEqual(norm_after, norm_before)
compare_scaling(grads)
# Small gradients should be left unchanged
grads = torch.rand(10, 10, device=device).div(10000), torch.ones(10, device=device).div(500)
for p, g in zip(l.parameters(), grads):

View File

@ -1,5 +1,11 @@
from . import parametrizations, rnn, stateless
from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_
from .clip_grad import (
_clip_grads_with_norm_ as clip_grads_with_norm_,
_get_total_norm as get_total_norm,
clip_grad_norm,
clip_grad_norm_,
clip_grad_value_,
)
from .convert_parameters import parameters_to_vector, vector_to_parameters
from .fusion import (
fuse_conv_bn_eval,
@ -19,6 +25,7 @@ from .weight_norm import remove_weight_norm, weight_norm
__all__ = [
"clip_grad_norm",
"clip_grad_norm_",
"clip_grads_with_norm_",
"clip_grad_value_",
"convert_conv2d_weight_memory_format",
"convert_conv3d_weight_memory_format",
@ -26,6 +33,7 @@ __all__ = [
"fuse_conv_bn_weights",
"fuse_linear_bn_eval",
"fuse_linear_bn_weights",
"get_total_norm",
"parameters_to_vector",
"parametrizations",
"remove_spectral_norm",

View File

@ -13,7 +13,11 @@ from torch.utils._foreach_utils import (
)
__all__ = ["clip_grad_norm_", "clip_grad_norm", "clip_grad_value_"]
__all__ = [
"clip_grad_norm_",
"clip_grad_norm",
"clip_grad_value_",
]
_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
@ -33,6 +37,141 @@ def _no_grad(func):
return _no_grad_wrapper
@_no_grad
def _get_total_norm(
tensors: _tensor_or_tensors,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: Optional[bool] = None,
) -> torch.Tensor:
r"""Compute the norm of an iterable of tensors.
The norm is computed over the norms of the individual tensors, as if the norms of
the individual tensors were concatenated into a single vector.
Args:
tensors (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will be normalized
norm_type (float): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of :attr:`tensors` is ``nan``, ``inf``, or ``-inf``.
Default: ``False``
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
Total norm of the tensors (viewed as a single vector).
"""
if isinstance(tensors, torch.Tensor):
tensors = [tensors]
else:
tensors = list(tensors)
norm_type = float(norm_type)
if len(tensors) == 0:
return torch.tensor(0.0)
first_device = tensors[0].device
grouped_tensors: Dict[
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
] = _group_tensors_by_device_and_dtype(
[tensors] # type: ignore[list-item]
) # type: ignore[assignment]
norms: List[Tensor] = []
for (device, _), ([device_tensors], _) in grouped_tensors.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_tensors, device)) or (
foreach and _device_has_foreach_support(device)
):
norms.extend(torch._foreach_norm(device_tensors, norm_type))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
norms.extend(
[torch.linalg.vector_norm(g, norm_type) for g in device_tensors]
)
total_norm = torch.linalg.vector_norm(
torch.stack([norm.to(first_device) for norm in norms]), norm_type
)
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. To disable "
"this error and scale the gradients by the non-finite norm anyway, "
"set `error_if_nonfinite=False`"
)
return total_norm
@_no_grad
def _clip_grads_with_norm_(
parameters: _tensor_or_tensors,
max_norm: float,
total_norm: torch.Tensor,
foreach: Optional[bool] = None,
) -> None:
r"""Scale the gradients of an iterable of parameters given a pre-calculated total norm and desired max norm.
The gradients will be scaled by the following calculation
.. math::
grad = grad * \frac{max\_norm}{total\_norm + 1e-6}
Gradients are modified in-place.
This function is equivalent to :func:`torch.nn.utils.clip_grad_norm_` with a pre-calculated
total norm.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float): max norm of the gradients
total_norm (Tensor): total norm of the gradients to use for clipping
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
fall back to the slow implementation for other device types.
Default: ``None``
Returns:
None
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
grads = [p.grad for p in parameters if p.grad is not None]
max_norm = float(max_norm)
if len(grads) == 0:
return
grouped_grads: Dict[
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
] = _group_tensors_by_device_and_dtype(
[grads]
) # type: ignore[assignment]
clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in device_grads:
g.mul_(clip_coef_clamped_device)
@_no_grad
def clip_grad_norm_(
parameters: _tensor_or_tensors,
@ -47,6 +186,9 @@ def clip_grad_norm_(
as if the norms of the individual gradients were concatenated into a single vector.
Gradients are modified in-place.
This function is equivalent to :func:`torch.nn.utils.get_total_norm` followed by
:func:`torch.nn.utils.clip_grads_with_norm_` with the ``total_norm`` returned by ``get_total_norm``.
Args:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
@ -66,61 +208,12 @@ def clip_grad_norm_(
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
else:
# prevent generators from being exhausted
parameters = list(parameters)
grads = [p.grad for p in parameters if p.grad is not None]
max_norm = float(max_norm)
norm_type = float(norm_type)
if len(grads) == 0:
return torch.tensor(0.0)
first_device = grads[0].device
grouped_grads: Dict[
Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
] = _group_tensors_by_device_and_dtype(
[grads]
) # type: ignore[assignment]
norms: List[Tensor] = []
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
norms.extend(torch._foreach_norm(device_grads, norm_type))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
total_norm = torch.linalg.vector_norm(
torch.stack([norm.to(first_device) for norm in norms]), norm_type
)
if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f"The total norm of order {norm_type} for gradients from "
"`parameters` is non-finite, so it cannot be clipped. To disable "
"this error and scale the gradients by the non-finite norm anyway, "
"set `error_if_nonfinite=False`"
)
clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for (device, _), ([device_grads], _) in grouped_grads.items(): # type: ignore[assignment]
if (foreach is None and _has_foreach_support(device_grads, device)) or (
foreach and _device_has_foreach_support(device)
):
torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
elif foreach:
raise RuntimeError(
f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
)
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in device_grads:
g.mul_(clip_coef_clamped_device)
total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
_clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
return total_norm