Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[BE][Ez]: Fully type nn.utils.clip_grad (#154801)

Fully types clip_grad and exposes typing annotations that were previously hidden by an untyped decorator.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154801
Approved by: https://github.com/jansel
Committed by: PyTorch MergeBot
Parent: dbad6d71c7
Commit: 9ce2732b68
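The "bad decorator" named in the commit message is the untyped _no_grad wrapper: because it carried no annotations, type checkers treated every function it decorated as untyped, hiding the signatures of clip_grad_norm_ and clip_grad_value_. A minimal sketch of the before/after pattern (illustrative only, not the patched source verbatim):

from typing import Callable, TypeVar

import torch
from typing_extensions import ParamSpec

_P = ParamSpec("_P")
_R = TypeVar("_R")


# Before: an untyped decorator. To mypy, the decorated function becomes an
# untyped callable, so the annotations on the public clip_grad functions
# vanish from the module's visible interface.
def _no_grad_untyped(func):
    def wrapper(*args, **kwargs):
        with torch.no_grad():
            return func(*args, **kwargs)

    return wrapper


# After: ParamSpec carries the wrapped function's exact parameter list
# through the decorator, and the TypeVar carries its return type, so the
# original signature survives type checking.
def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
    def _no_grad_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
        with torch.no_grad():
            return func(*args, **kwargs)

    return _no_grad_wrapper

The real module keeps details not shown in the hunks below (it imports functools and types as well); this sketch only reproduces the typing pattern the commit applies.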
--- a/torch/nn/utils/clip_grad.py
+++ b/torch/nn/utils/clip_grad.py
@@ -1,11 +1,9 @@
-# mypy: allow-untyped-decorators
-# mypy: allow-untyped-defs
 import functools
 import types
 import typing
 import warnings
-from typing import cast, Optional, Union
-from typing_extensions import deprecated
+from typing import Callable, cast, Optional, TypeVar, Union
+from typing_extensions import deprecated, ParamSpec, TypeAlias
 
 import torch
 from torch import Tensor
@@ -23,19 +21,22 @@ __all__ = [
 ]
 
 
-_tensor_or_tensors = Union[
+_TensorOrTensors: TypeAlias = Union[
     torch.Tensor,
     typing.Iterable[torch.Tensor],  # noqa: UP006 - needed until XLA's patch is updated
 ]
 
+_P = ParamSpec("_P")
+_R = TypeVar("_R")
+
 
-def _no_grad(func):
+def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
     """
     This wrapper is needed to avoid a circular import when using @torch.no_grad on the exposed functions
     clip_grad_norm_ and clip_grad_value_ themselves.
     """
 
-    def _no_grad_wrapper(*args, **kwargs):
+    def _no_grad_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
         with torch.no_grad():
             return func(*args, **kwargs)
 
@@ -45,7 +46,7 @@ def _no_grad(func):
 
 @_no_grad
 def _get_total_norm(
-    tensors: _tensor_or_tensors,
+    tensors: _TensorOrTensors,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
     foreach: Optional[bool] = None,
@@ -116,7 +117,7 @@ def _get_total_norm(
 
 @_no_grad
 def _clip_grads_with_norm_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     total_norm: torch.Tensor,
     foreach: Optional[bool] = None,
@@ -180,7 +181,7 @@ def _clip_grads_with_norm_(
 
 @_no_grad
 def clip_grad_norm_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
@@ -235,7 +236,7 @@ def clip_grad_norm_(
     category=FutureWarning,
 )
 def clip_grad_norm(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     max_norm: float,
     norm_type: float = 2.0,
     error_if_nonfinite: bool = False,
@@ -252,7 +253,7 @@ def clip_grad_norm(
 
 @_no_grad
 def clip_grad_value_(
-    parameters: _tensor_or_tensors,
+    parameters: _TensorOrTensors,
     clip_value: float,
     foreach: Optional[bool] = None,
 ) -> None:
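With _no_grad typed and the alias renamed to an explicit _TensorOrTensors: TypeAlias, type checkers now see the real signatures of the public functions. A hedged usage sketch (assumes a PyTorch build that contains this change):

import torch
from torch import nn

model = nn.Linear(4, 2)
model(torch.randn(8, 4)).sum().backward()

# The annotated signature accepts a Tensor or an iterable of Tensors and
# returns the total gradient norm as a torch.Tensor.
total_norm: torch.Tensor = nn.utils.clip_grad_norm_(
    model.parameters(), max_norm=1.0, norm_type=2.0
)

# clip_grad_value_ is annotated to return None.
nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)

Under the old untyped decorator, both calls would have been seen by mypy as untyped (effectively returning Any), so annotations like the one on total_norm went unchecked.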