[BE][Ez]: Fully type nn.utils.clip_grad (#154801)

Fully types clip_grad and exposes typing annotations that were hidden by a bad decorator.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154801
Approved by: https://github.com/jansel
Author: Aaron Gokaslan
Date: 2025-05-31 23:06:41 +00:00
Committed by: PyTorch MergeBot
Parent: dbad6d71c7
Commit: 9ce2732b68
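
The "bad decorator" here is the untyped `_no_grad` wrapper: an untyped decorator returns `Any` as far as mypy is concerned, so every function it wraps (e.g. `clip_grad_norm_`) loses its signature for type checkers and IDEs. The sketch below illustrates the `ParamSpec`/`TypeVar` pattern the diff adopts; it is not the PyTorch source, and the names `untyped_passthrough`, `typed_passthrough`, and `scale` are invented for the example:

```python
# Minimal sketch (not the PyTorch source) of the pattern adopted in the diff.
import functools
from typing import Callable, TypeVar

from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


def untyped_passthrough(func):
    # The decorator itself is untyped, so its return type is Any and any
    # function it wraps loses its signature as far as type checkers can see.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


def typed_passthrough(func: Callable[_P, _R]) -> Callable[_P, _R]:
    # ParamSpec/TypeVar carry the wrapped function's parameters and return
    # type through the decorator, so the original signature stays visible.
    @functools.wraps(func)
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        return func(*args, **kwargs)

    return wrapper


@typed_passthrough
def scale(x: float, factor: float = 2.0) -> float:
    return x * factor


# mypy/IDEs still see scale as (x: float, factor: float = 2.0) -> float;
# decorating with untyped_passthrough would have collapsed it to Any.
print(scale(3.0))
```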


--- a/torch/nn/utils/clip_grad.py
+++ b/torch/nn/utils/clip_grad.py
@@ -1,11 +1,9 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import functools
 import types
 import typing
 import warnings
-from typing import cast, Optional, Union
-from typing_extensions import deprecated
+from typing import Callable, cast, Optional, TypeVar, Union
+from typing_extensions import deprecated, ParamSpec, TypeAlias
 
 import torch
 from torch import Tensor
@@ -23,19 +21,22 @@ __all__ = [
 ]
 
-_tensor_or_tensors = Union[
+_TensorOrTensors: TypeAlias = Union[
     torch.Tensor,
     typing.Iterable[torch.Tensor],  # noqa: UP006 - needed until XLA's patch is updated
 ]
 
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
 
-def _no_grad(func):
+def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
     """
     This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
     clip_grad_norm_ and clip_grad_value_ themselves.
     """
 
-    def _no_grad_wrapper(*args, **kwargs):
+    def _no_grad_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
         with torch.no_grad():
             return func(*args, **kwargs)
 
@@ -45,7 +46,7 @@ def _no_grad(func):
 
 @_no_grad
 def _get_total_norm(
-    tensors: _tensor_or_tensors,
+    tensors: _TensorOrTensors,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
     foreach: Optional[bool] = None,
@@ -116,7 +117,7 @@ def _get_total_norm(
 
 @_no_grad
 def _clip_grads_with_norm_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     total_norm: torch.Tensor,
     foreach: Optional[bool] = None,
@@ -180,7 +181,7 @@ def _clip_grads_with_norm_(
 
 @_no_grad
 def clip_grad_norm_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
@@ -235,7 +236,7 @@ def clip_grad_norm(
     category=FutureWarning,
 )
 def clip_grad_norm(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
@@ -252,7 +253,7 @@ def clip_grad_norm(
 
 @_no_grad
 def clip_grad_value_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     clip_value: float,
     foreach: Optional[bool] = None,
 ) -> None:
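
With these annotations in place, downstream code calling `torch.nn.utils.clip_grad_norm_` gets the full signature back from type checkers instead of `Any`. A small usage sketch, assuming a toy model and optimizer that are not part of this PR:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).sum()
loss.backward()

# clip_grad_norm_ returns the total gradient norm (pre-clipping) as a Tensor;
# with _no_grad typed via ParamSpec, this return type is no longer hidden
# from mypy or IDE tooling.
total_norm: torch.Tensor = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

optimizer.step()
```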