[BE][Ez]: Fully type nn.utils.clip_grad (#154801)

Fully types clip_grad and exposes typing annotations that were previously hidden by a bad decorator.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154801
Approved by: https://github.com/jansel
Author: Aaron Gokaslan
Date: 2025-05-31 23:06:41 +00:00
Committed by: PyTorch MergeBot
Parent: dbad6d71c7
Commit: 9ce2732b68
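
The diff below replaces the untyped _no_grad decorator with one parameterized by ParamSpec/TypeVar, so type checkers no longer collapse the decorated clip_grad functions to a bare Callable. A minimal standalone sketch of that decorator pattern for illustration (the functools.wraps call here is standard practice, not necessarily the exact body used in the file):

import functools
from typing import Callable, TypeVar

import torch
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
    # Run the wrapped function under torch.no_grad() while preserving its
    # signature for type checkers: Callable[_P, _R] in, Callable[_P, _R] out.
    @functools.wraps(func)
    def _no_grad_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        with torch.no_grad():
            return func(*args, **kwargs)

    return _no_grad_wrapper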

torch/nn/utils/clip_grad.py

@@ -1,11 +1,9 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import functools
 import types
 import typing
 import warnings
-from typing import cast, Optional, Union
-from typing_extensions import deprecated
+from typing import Callable, cast, Optional, TypeVar, Union
+from typing_extensions import deprecated, ParamSpec, TypeAlias
 
 import torch
 from torch import Tensor
@@ -23,19 +21,22 @@ __all__ = [
 ]
 
-_tensor_or_tensors = Union[
+_TensorOrTensors: TypeAlias = Union[
     torch.Tensor,
     typing.Iterable[torch.Tensor],  # noqa: UP006 - needed until XLA's patch is updated
 ]
 
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
 
-def _no_grad(func):
+def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
     """
     This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
     clip_grad_norm_ and clip_grad_value_ themselves.
     """
 
-    def _no_grad_wrapper(*args, **kwargs):
+    def _no_grad_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
         with torch.no_grad():
             return func(*args, **kwargs)
@@ -45,7 +46,7 @@ def _no_grad(func):
 @_no_grad
 def _get_total_norm(
-    tensors: _tensor_or_tensors,
+    tensors: _TensorOrTensors,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
     foreach: Optional[bool] = None,
@@ -116,7 +117,7 @@ def _get_total_norm(
 @_no_grad
 def _clip_grads_with_norm_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     total_norm: torch.Tensor,
     foreach: Optional[bool] = None,
@@ -180,7 +181,7 @@ def _clip_grads_with_norm_(
 @_no_grad
 def clip_grad_norm_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
@@ -235,7 +236,7 @@ def clip_grad_norm_(
     category=FutureWarning,
 )
 def clip_grad_norm(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
@@ -252,7 +253,7 @@ def clip_grad_norm(
 @_no_grad
 def clip_grad_value_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     clip_value: float,
     foreach: Optional[bool] = None,
 ) -> None:
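
With the annotations exposed through the typed decorator, callers see the real signatures of the public functions. A short usage sketch (not part of this commit) of what a type checker now observes:

import torch
from torch import nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# clip_grad_norm_ is annotated to return a Tensor (the total gradient norm) ...
total_norm: torch.Tensor = nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0, norm_type=2.0
)

# ... while clip_grad_value_ clips gradients in place and is annotated -> None.
nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)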