Mirror of https://github.com/pytorch/pytorch.git
Apply UFMT to low traffic torch modules (#106249)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
Committed by: PyTorch MergeBot
Parent: a4ebc61f15
Commit: 3bf922a6ce
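Note: UFMT runs usort (import sorting) followed by black (code formatting), which is what produces every change in the diff below. A minimal sketch of the black half of that pipeline, assuming the `black` package is installed; this is illustrative, not how PyTorch's lintrunner invokes it:

    import black

    # The same rewrite the first hunk of autocast_mode.py shows for __all__:
    # black normalizes string quotes to double quotes.
    src = "__all__ = ['autocast_decorator', 'autocast']\n"
    print(black.format_str(src, mode=black.Mode()), end="")
    # __all__ = ["autocast_decorator", "autocast"]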
@@ -1 +1 @@
-from .autocast_mode import autocast, _enter_autocast, _exit_autocast
+from .autocast_mode import _enter_autocast, _exit_autocast, autocast
@@ -1,20 +1,24 @@
-import torch
 import functools
 import warnings
 
 from typing import Any, Optional
+
+import torch
 from torch.types import _dtype
 
-__all__ = ['autocast_decorator', 'autocast']
+__all__ = ["autocast_decorator", "autocast"]
+
 
 def autocast_decorator(autocast_instance, func):
     @functools.wraps(func)
     def decorate_autocast(*args, **kwargs):
         with autocast_instance:
             return func(*args, **kwargs)
-    decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in script mode'  # type: ignore[attr-defined]
+
+    decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode"  # type: ignore[attr-defined]
     return decorate_autocast
 
+
 class autocast:
     r"""
     Instances of :class:`autocast` serve as context managers or decorators that
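For reference, `autocast_decorator` is what backs the decorator form of `torch.autocast`; a small usage sketch (CPU, so it runs anywhere):

    import torch

    @torch.autocast(device_type="cpu", dtype=torch.bfloat16)
    def matmul_bf16(x, w):
        return x @ w

    y = matmul_bf16(torch.randn(4, 4), torch.randn(4, 4))
    print(y.dtype)  # torch.bfloat16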
@@ -179,10 +183,14 @@ class autocast:
         cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled.
             Default: ``True``
     """
-    def __init__(self, device_type : str,
-                 dtype : Optional[_dtype] = None,
-                 enabled : bool = True,
-                 cache_enabled : Optional[bool] = None):
+
+    def __init__(
+        self,
+        device_type: str,
+        dtype: Optional[_dtype] = None,
+        enabled: bool = True,
+        cache_enabled: Optional[bool] = None,
+    ):
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
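The reformatted signature above is the public entry point; a context-manager usage sketch matching it:

    import torch

    x, w = torch.randn(8, 8), torch.randn(8, 8)
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, cache_enabled=True):
        y = x @ w
    print(y.dtype)  # torch.bfloat16: matmul is an autocast-eligible op on CPU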
@@ -192,71 +200,90 @@ class autocast:
             return
         self.device = device_type
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
-        if self.device == 'cuda':
+        if self.device == "cuda":
             self.fast_dtype = torch.get_autocast_gpu_dtype()
-        elif self.device == 'cpu':
+        elif self.device == "cpu":
             self.fast_dtype = torch.get_autocast_cpu_dtype()
-        elif self.device == 'xpu':
+        elif self.device == "xpu":
             self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == 'ipu':
+        elif self.device == "ipu":
             self.fast_dtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-        elif self.device == 'hpu':
+        elif self.device == "hpu":
             self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == 'xla':
+        elif self.device == "xla":
             self.fast_dtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
         elif self.device == self.custom_backend_name:
-            necessary_funcs = ['is_autocast_enabled', 'set_autocast_enabled', 'get_autocast_dtype',
-                               'set_autocast_dtype', 'get_amp_supported_dtype']
+            necessary_funcs = [
+                "is_autocast_enabled",
+                "set_autocast_enabled",
+                "get_autocast_dtype",
+                "set_autocast_dtype",
+                "get_amp_supported_dtype",
+            ]
             message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
             message += "registered a module or the module miss some necessary funcs. The backend should register "
             message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
             message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, "
             message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) "
-            message += "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
+            message += (
+                "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
+            )
 
             assert hasattr(torch, self.custom_backend_name), message
             self.custom_device_mod = getattr(torch, self.custom_backend_name)
             for func in necessary_funcs:
-                assert hasattr(self.custom_device_mod, func), message + f"But the func `{func}` is missing. \n"
+                assert hasattr(self.custom_device_mod, func), (
+                    message + f"But the func `{func}` is missing. \n"
+                )
 
             self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
         else:
-            raise RuntimeError(f'User specified an unsupported autocast device_type \'{self.device}\'')
+            raise RuntimeError(
+                f"User specified an unsupported autocast device_type '{self.device}'"
+            )
         self._cache_enabled = torch.is_autocast_cache_enabled()
-        if enabled and torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
-            warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
+        if (
+            enabled
+            and torch.cuda.amp.common.amp_definitely_not_available()
+            and self.device == "cuda"
+        ):
+            warnings.warn(
+                "User provided device_type of 'cuda', but CUDA is not available. Disabling"
+            )
             enabled = False
         if dtype is not None:
            self.fast_dtype = dtype
         if cache_enabled is not None:
             self._cache_enabled = cache_enabled
 
-        if self.device == 'cpu':
+        if self.device == "cpu":
             supported_dtype = [torch.bfloat16]
             if self.fast_dtype not in supported_dtype:
-                error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n'
-                error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.'
+                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += (
+                    "CPU Autocast only supports dtype of torch.bfloat16 currently."
+                )
                 warnings.warn(error_message)
                 enabled = False
-        elif self.device == 'xpu':
+        elif self.device == "xpu":
             supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype:
-                error_message = 'In XPU autocast, but the target dtype is not supported. Disabling autocast.\n'
-                error_message += 'XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
+                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                 warnings.warn(error_message)
                 enabled = False
-        elif self.device == 'ipu':
+        elif self.device == "ipu":
             supported_dtypes = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtypes:
-                error_message = 'In IPU autocast, but the target dtype is not supported. Disabling autocast.\n'
-                error_message += 'IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
+                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                 warnings.warn(error_message)
                 enabled = False
-        elif self.device == 'hpu':
+        elif self.device == "hpu":
             supported_dtype = [torch.bfloat16, torch.float16]
             if self.fast_dtype not in supported_dtype:
-                error_message = 'In HPU autocast, but the target dtype is not supported. Disabling autocast.\n'
-                error_message += 'HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
+                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
                 warnings.warn(error_message)
                 enabled = False
         elif self.device == self.custom_backend_name:
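The `necessary_funcs` check above spells out the contract a custom (PrivateUse1) backend module must satisfy. A hedged sketch of a module meeting that contract; `my_backend`, the `state` dict, and the lambdas are hypothetical stand-ins, not a real integration:

    import types

    import torch

    state = {"enabled": False, "dtype": torch.float16}

    mod = types.ModuleType("my_backend")
    mod.is_autocast_enabled = lambda: state["enabled"]
    mod.set_autocast_enabled = lambda enabled: state.update(enabled=enabled)
    mod.get_autocast_dtype = lambda: state["dtype"]
    mod.set_autocast_dtype = lambda dtype: state.update(dtype=dtype)
    mod.get_amp_supported_dtype = lambda: [torch.float16, torch.bfloat16]

    # A real backend would first rename PrivateUse1, then register the module:
    # torch.utils.rename_privateuse1_backend("my_backend")
    # torch._register_device_module("my_backend", mod)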
@@ -264,17 +291,27 @@ class autocast:
             if self.fast_dtype not in supported_dtype:
                 error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. "
                 error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
-                error_message += ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
+                error_message += (
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
+                )
                 warnings.warn(error_message)
                 enabled = False
-        elif self.device == 'cuda':
-            if enabled and self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
-                raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
-        elif self.device == 'xla':
+        elif self.device == "cuda":
+            if (
+                enabled
+                and self.fast_dtype == torch.bfloat16
+                and not torch.cuda.is_bf16_supported()
+            ):
+                raise RuntimeError(
+                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
+                )
+        elif self.device == "xla":
             supported_dtype = [torch.bfloat16]
             if self.fast_dtype not in supported_dtype:
-                error_message = 'In XLA autocast, but the target dtype is not supported. Disabling autocast.\n'
-                error_message += 'XLA Autocast only supports dtype of torch.bfloat16 currently.'
+                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += (
+                    "XLA Autocast only supports dtype of torch.bfloat16 currently."
+                )
                 warnings.warn(error_message)
                 enabled = False
         self._enabled = enabled
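Unlike the dtype warnings, the bfloat16 guard above is a hard error; a small sketch of choosing a safe CUDA dtype up front:

    import torch

    if torch.cuda.is_available():
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        with torch.autocast(device_type="cuda", dtype=dtype):
            y = torch.randn(4, 4, device="cuda") @ torch.randn(4, 4, device="cuda")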
@@ -285,31 +322,31 @@ class autocast:
             return self
 
         self.prev_cache_enabled = torch.is_autocast_cache_enabled()
-        if self.device == 'cpu':
+        if self.device == "cpu":
             self.prev = torch.is_autocast_cpu_enabled()
             self.prev_fastdtype = torch.get_autocast_cpu_dtype()
             torch.set_autocast_cpu_enabled(self._enabled)
             torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
             torch.autocast_increment_nesting()
-        elif self.device == 'xpu':
-            self.prev = torch.xpu.is_autocast_xpu_enabled()    # type: ignore[attr-defined]
+        elif self.device == "xpu":
+            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
             self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
             torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
             torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
             torch.autocast_increment_nesting()
-        elif self.device == 'ipu':
-            self.prev = torch.is_autocast_ipu_enabled()    # type: ignore[attr-defined]
+        elif self.device == "ipu":
+            self.prev = torch.is_autocast_ipu_enabled()  # type: ignore[attr-defined]
             self.prev_fastdtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
             torch.set_autocast_ipu_enabled(self._enabled)  # type: ignore[attr-defined]
             torch.set_autocast_ipu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
             torch.autocast_increment_nesting()
-        elif self.device == 'hpu':
-            self.prev = torch.hpu.is_autocast_hpu_enabled()    # type: ignore[attr-defined]
+        elif self.device == "hpu":
+            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
             self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
             torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
             torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
             torch.autocast_increment_nesting()
-        elif self.device == 'xla':
+        elif self.device == "xla":
             self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
             self.prev_fastdtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
             torch.set_autocast_xla_enabled(self._enabled)  # type: ignore[attr-defined]
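`__enter__` saves the previous thread-local state before installing its own, which is what lets nested and locally-disabled regions compose; a sketch using the CPU path and the same accessors the code above calls:

    import torch

    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        assert torch.is_autocast_cpu_enabled()
        with torch.autocast(device_type="cpu", enabled=False):
            assert not torch.is_autocast_cpu_enabled()
        # the inner __exit__ restored the outer region's state
        assert torch.is_autocast_cpu_enabled()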
@@ -334,31 +371,31 @@ class autocast:
             return
 
         # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
-        if self.device == 'cpu':
+        if self.device == "cpu":
             if torch.autocast_decrement_nesting() == 0:
                 torch.clear_autocast_cache()
             torch.set_autocast_cpu_enabled(self.prev)
             torch.set_autocast_cpu_dtype(self.prev_fastdtype)
-        elif self.device == 'xpu':
+        elif self.device == "xpu":
             if torch.autocast_decrement_nesting() == 0:
                 torch.clear_autocast_cache()
-            torch.xpu.set_autocast_xpu_enabled(self.prev)    # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)    # type: ignore[attr-defined]
-        elif self.device == 'ipu':
+            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
+            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
+        elif self.device == "ipu":
             if torch.autocast_decrement_nesting() == 0:
                 torch.clear_autocast_cache()
-            torch.set_autocast_ipu_enabled(self.prev)    # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.prev_fastdtype)    # type: ignore[attr-defined]
-        elif self.device == 'hpu':
+            torch.set_autocast_ipu_enabled(self.prev)  # type: ignore[attr-defined]
+            torch.set_autocast_ipu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
+        elif self.device == "hpu":
             if torch.autocast_decrement_nesting() == 0:
                 torch.clear_autocast_cache()
-            torch.hpu.set_autocast_hpu_enabled(self.prev)    # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)    # type: ignore[attr-defined]
-        elif self.device == 'xla':
+            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
+            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
+        elif self.device == "xla":
             if torch.autocast_decrement_nesting() == 0:
                 torch.clear_autocast_cache()
-            torch.set_autocast_xla_enabled(self.prev)    # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.prev_fastdtype)    # type: ignore[attr-defined]
+            torch.set_autocast_xla_enabled(self.prev)  # type: ignore[attr-defined]
+            torch.set_autocast_xla_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
         elif self.device == self.custom_backend_name:
             if torch.autocast_decrement_nesting() == 0:
                 torch.clear_autocast_cache()
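Taken together, `__enter__`/`__exit__` behave like this hand-rolled save/restore (CPU path, using the same public helpers the methods call; shown for illustration only):

    import torch

    prev = torch.is_autocast_cpu_enabled()
    prev_dtype = torch.get_autocast_cpu_dtype()
    torch.set_autocast_cpu_enabled(True)
    torch.set_autocast_cpu_dtype(torch.bfloat16)
    torch.autocast_increment_nesting()
    try:
        y = torch.randn(4, 4) @ torch.randn(4, 4)
    finally:
        # Drop the weight cache only when leaving the outermost region.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_cpu_enabled(prev)
        torch.set_autocast_cpu_dtype(prev_dtype)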
@@ -377,13 +414,16 @@ class autocast:
             return func
         return autocast_decorator(self, func)
 
+
 # These functions aren't meant for public usage.
 # They are what we trace into a graph during pre_dispatch tracing
 # when we encounter an autocast context manager.
 def _enter_autocast(*vals):
     # For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph.
     if torch._C._is_torch_function_mode_enabled():
-        return torch.overrides.handle_torch_function(torch.amp._enter_autocast, [], *vals)
+        return torch.overrides.handle_torch_function(
+            torch.amp._enter_autocast, [], *vals
+        )
    mode = torch.amp.autocast(*vals)
    mode.__enter__()
    return mode
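Outside of pre-dispatch tracing, these private helpers just drive the context manager by hand; a sketch of the eager-mode path, assuming the positional args mirror autocast(device_type, dtype, ...) and that `_exit_autocast(mode)` calls `mode.__exit__`:

    import torch

    mode = torch.amp._enter_autocast("cpu", torch.bfloat16)
    assert torch.is_autocast_cpu_enabled()
    torch.amp._exit_autocast(mode)
    # assuming autocast was off beforehand, the previous state is restored
    assert not torch.is_autocast_cpu_enabled()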