[BE][Ez]: Fully type nn.utils.clip_grad (#154801)
Fully types clip_grad and exposes typing annotations that were hidden by a bad decorator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154801
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: dbad6d71c7
Commit: 9ce2732b68
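Per the commit message, the "bad decorator" was the untyped _no_grad wrapper: a decorator with no annotations makes type checkers treat the decorated clip_grad_norm_ / clip_grad_value_ as returning Any, hiding their signatures. A minimal before/after sketch of that pattern (generic names, not the PyTorch source):

import torch
from typing import Callable, TypeVar
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")

def no_grad_untyped(func):
    # No annotations: type checkers see the decorated function as Any,
    # so every signature it wraps is effectively erased.
    def wrapper(*args, **kwargs):
        with torch.no_grad():
            return func(*args, **kwargs)
    return wrapper

def no_grad_typed(func: Callable[_P, _R]) -> Callable[_P, _R]:
    # ParamSpec + TypeVar carry the wrapped function's parameters and return type through.
    def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        with torch.no_grad():
            return func(*args, **kwargs)
    return wrapper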
@@ -1,11 +1,9 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import functools
 import types
 import typing
 import warnings
-from typing import cast, Optional, Union
-from typing_extensions import deprecated
+from typing import Callable, cast, Optional, TypeVar, Union
+from typing_extensions import deprecated, ParamSpec, TypeAlias
 
 import torch
 from torch import Tensor
@@ -23,19 +21,22 @@ __all__ = [
 ]
 
 
-_tensor_or_tensors = Union[
+_TensorOrTensors: TypeAlias = Union[
     torch.Tensor,
     typing.Iterable[torch.Tensor],  # noqa: UP006 - needed until XLA's patch is updated
 ]
 
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
 
-def _no_grad(func):
+def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
     """
     This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
     clip_grad_norm_ and clip_grad_value_ themselves.
     """
 
-    def _no_grad_wrapper(*args, **kwargs):
+    def _no_grad_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
         with torch.no_grad():
             return func(*args, **kwargs)
 
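With _no_grad typed through ParamSpec("_P") and TypeVar("_R"), the decorator now carries each wrapped function's parameter and return types through instead of erasing them. A hedged way to see the effect with a type checker (the reveal_type output is paraphrased, not an exact mypy transcript):

# Run a type checker over this snippet, e.g.: mypy check_clip_grad.py
from typing_extensions import reveal_type

from torch.nn.utils import clip_grad_norm_

# Previously the untyped decorator hid this signature; now the checker reports the real
# (parameters, max_norm, norm_type, error_if_nonfinite, foreach) -> Tensor shape.
reveal_type(clip_grad_norm_)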
@@ -45,7 +46,7 @@ def _no_grad(func):
 
 @_no_grad
 def _get_total_norm(
-    tensors: _tensor_or_tensors,
+    tensors: _TensorOrTensors,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
     foreach: Optional[bool] = None,
@@ -116,7 +117,7 @@ def _get_total_norm(
 
 @_no_grad
 def _clip_grads_with_norm_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     total_norm: torch.Tensor,
     foreach: Optional[bool] = None,
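The two helpers retyped here split norm-based clipping into its two conceptual steps: _get_total_norm reduces all gradient norms to a single total norm, and _clip_grads_with_norm_ rescales the gradients by max_norm / total_norm when that ratio is below one. A plain-Python sketch of the idea, assuming p-norm aggregation (an illustration, not the actual implementation):

import torch

def total_norm_of(tensors: list[torch.Tensor], norm_type: float = 2.0) -> torch.Tensor:
    # The p-norm of the per-tensor p-norms equals the p-norm of all elements viewed as one vector.
    per_tensor = torch.stack([torch.linalg.vector_norm(t, norm_type) for t in tensors])
    return torch.linalg.vector_norm(per_tensor, norm_type)

def clip_with_norm_(grads: list[torch.Tensor], max_norm: float, total_norm: torch.Tensor) -> None:
    # The scale factor is capped at 1 so gradients are only ever shrunk, never grown.
    clip_coef = (max_norm / (total_norm + 1e-6)).clamp(max=1.0)
    for g in grads:
        g.mul_(clip_coef)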
@@ -180,7 +181,7 @@ def _clip_grads_with_norm_(
 
 @_no_grad
 def clip_grad_norm_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
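For context, clip_grad_norm_ is typically called between backward() and optimizer.step(); with the exposed annotations its parameters argument is _TensorOrTensors, i.e. a single tensor or an iterable of tensors, and it returns the total gradient norm as a torch.Tensor. A minimal usage sketch (model, data, and optimizer are illustrative):

import torch
from torch import nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs, targets = torch.randn(8, 10), torch.randn(8, 2)
loss = nn.functional.mse_loss(model(inputs), targets)

optimizer.zero_grad()
loss.backward()
# Clip after backward() and before step(); returns the (pre-clipping) total grad norm.
total_norm = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()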
@@ -235,7 +236,7 @@ def clip_grad_norm_(
     category=FutureWarning,
 )
 def clip_grad_norm(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
@@ -252,7 +253,7 @@ def clip_grad_norm(
 
 @_no_grad
 def clip_grad_value_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     clip_value: float,
     foreach: Optional[bool] = None,
 ) -> None:
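clip_grad_value_, the last function touched here, clamps each gradient element into [-clip_value, clip_value] in place and, as its -> None annotation indicates, returns nothing. A short usage sketch (the model is illustrative):

import torch
from torch import nn

model = nn.Linear(10, 2)
model(torch.randn(4, 10)).sum().backward()

# Clamps every gradient entry into [-0.5, 0.5] in place; the call returns None.
nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)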