diff --git a/torch/cpu/amp/autocast_mode.py b/torch/cpu/amp/autocast_mode.py index 630fe9e58a46..f0f81060d4a0 100644 --- a/torch/cpu/amp/autocast_mode.py +++ b/torch/cpu/amp/autocast_mode.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import sys from typing import Any from typing_extensions import deprecated @@ -8,17 +9,36 @@ import torch __all__ = ["autocast"] +@deprecated( + "`torch.cpu.amp.autocast(args...)` is deprecated. " + "Please use `torch.amp.autocast('cpu', args...)` instead.", + category=FutureWarning, +) class autocast(torch.amp.autocast_mode.autocast): r""" See :class:`torch.autocast`. ``torch.cpu.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cpu", args...)`` instead. """ - @deprecated( - "`torch.cpu.amp.autocast(args...)` is deprecated. " - "Please use `torch.amp.autocast('cpu', args...)` instead.", - category=FutureWarning, - ) + # TODO: remove this conditional once we stop supporting Python < 3.13 + # Prior to Python 3.13, inspect.signature could not retrieve the correct + # signature information for classes decorated with @deprecated (unless + # the __new__ static method was explicitly defined); + # + # However, this issue has been fixed in Python 3.13 and later versions. + if sys.version_info < (3, 13): + + def __new__( + cls, + enabled: bool = True, + dtype: torch.dtype = torch.bfloat16, + cache_enabled: bool = True, + ): + return super().__new__(cls) + + def __init_subclass__(cls): + pass + def __init__( self, enabled: bool = True, diff --git a/torch/cuda/amp/autocast_mode.py b/torch/cuda/amp/autocast_mode.py index d52ff7cf672b..e6b63c708d3f 100644 --- a/torch/cuda/amp/autocast_mode.py +++ b/torch/cuda/amp/autocast_mode.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-defs import functools +import sys from typing import Any from typing_extensions import deprecated @@ -9,17 +10,36 @@ import torch __all__ = ["autocast", "custom_fwd", "custom_bwd"] +@deprecated( + "`torch.cuda.amp.autocast(args...)` is deprecated. " + "Please use `torch.amp.autocast('cuda', args...)` instead.", + category=FutureWarning, +) class autocast(torch.amp.autocast_mode.autocast): r"""See :class:`torch.autocast`. ``torch.cuda.amp.autocast(args...)`` is deprecated. Please use ``torch.amp.autocast("cuda", args...)`` instead. """ - @deprecated( - "`torch.cuda.amp.autocast(args...)` is deprecated. " - "Please use `torch.amp.autocast('cuda', args...)` instead.", - category=FutureWarning, - ) + # TODO: remove this conditional once we stop supporting Python < 3.13 + # Prior to Python 3.13, inspect.signature could not retrieve the correct + # signature information for classes decorated with @deprecated (unless + # the __new__ static method was explicitly defined); + # + # However, this issue has been fixed in Python 3.13 and later versions. + if sys.version_info < (3, 13): + + def __new__( + cls, + enabled: bool = True, + dtype: torch.dtype = torch.float16, + cache_enabled: bool = True, + ): + return super().__new__(cls) + + def __init_subclass__(cls): + pass + def __init__( self, enabled: bool = True,