Apply UFMT to low traffic torch modules (#106249)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106249
Approved by: https://github.com/Skylion007
Edward Z. Yang authored this change on 2023-07-29 10:51:26 -04:00; committed by PyTorch MergeBot.
parent a4ebc61f15, commit 3bf922a6ce
163 changed files with 8472 additions and 4412 deletions
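UFMT runs usort (import sorting) followed by black (code formatting), so every hunk below is a mechanical rewrite with no behavior change. As a hedged illustration (assuming the `black` package is installed), the quote normalization that accounts for most of the single-line changes in this diff can be reproduced directly:

```python
# Sketch only: black's quote normalization, which explains the pervasive
# single-quote -> double-quote churn in the hunks below.
import black

src = "__all__ = ['autocast_decorator', 'autocast']\n"
print(black.format_str(src, mode=black.Mode()), end="")
# -> __all__ = ["autocast_decorator", "autocast"]
```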

torch/amp/__init__.py

@@ -1 +1 @@
from .autocast_mode import autocast, _enter_autocast, _exit_autocast
from .autocast_mode import _enter_autocast, _exit_autocast, autocast
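The only change in this file is import ordering. usort sorts imported names case-sensitively by ASCII code point, where `_` precedes lowercase letters; a hypothetical one-liner (not usort itself) shows the same ordering:

```python
# Illustration only: ASCII "_" (0x5F) sorts before "a" (0x61), which is why
# _enter_autocast and _exit_autocast now precede autocast in the import.
print(sorted(["autocast", "_enter_autocast", "_exit_autocast"]))
# -> ['_enter_autocast', '_exit_autocast', 'autocast']
```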

torch/amp/autocast_mode.py

@@ -1,20 +1,24 @@
import torch
import functools
import warnings
from typing import Any, Optional
import torch
from torch.types import _dtype
__all__ = ['autocast_decorator', 'autocast']
__all__ = ["autocast_decorator", "autocast"]
def autocast_decorator(autocast_instance, func):
@functools.wraps(func)
def decorate_autocast(*args, **kwargs):
with autocast_instance:
return func(*args, **kwargs)
decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in script mode' # type: ignore[attr-defined]
decorate_autocast.__script_unsupported = "@autocast() decorator is not supported in script mode" # type: ignore[attr-defined]
return decorate_autocast
class autocast:
r"""
Instances of :class:`autocast` serve as context managers or decorators that
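The docstring above (cut off by the hunk) describes autocast's dual role as context manager and decorator. A minimal usage sketch, assuming a CPU-only setup with bfloat16:

```python
import torch

# Context-manager form: eligible ops inside run in the autocast dtype.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    x = torch.randn(4, 4)
    y = x @ x  # matmul is autocast-eligible on CPU; y is bfloat16

# Decorator form, routed through autocast_decorator shown above.
@torch.autocast(device_type="cpu", dtype=torch.bfloat16)
def forward(a, b):
    return a @ b
```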
@@ -179,10 +183,14 @@ class autocast:
cache_enabled(bool, optional): Whether the weight cache inside autocast should be enabled.
Default: ``True``
"""
def __init__(self, device_type : str,
dtype : Optional[_dtype] = None,
enabled : bool = True,
cache_enabled : Optional[bool] = None):
def __init__(
self,
device_type: str,
dtype: Optional[_dtype] = None,
enabled: bool = True,
cache_enabled: Optional[bool] = None,
):
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = device_type
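The signature reflow above is purely cosmetic (black's one-parameter-per-line style once a signature exceeds the line length); a quick sketch confirming the constructor accepts the same arguments as before:

```python
# Sketch: the reflowed __init__ takes exactly the same four parameters.
import torch

ctx = torch.autocast(
    device_type="cpu",
    dtype=torch.bfloat16,
    enabled=True,
    cache_enabled=True,
)
with ctx:
    pass
```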
@@ -192,71 +200,90 @@ class autocast:
return
self.device = device_type
self.custom_backend_name = torch._C._get_privateuse1_backend_name()
if self.device == 'cuda':
if self.device == "cuda":
self.fast_dtype = torch.get_autocast_gpu_dtype()
elif self.device == 'cpu':
elif self.device == "cpu":
self.fast_dtype = torch.get_autocast_cpu_dtype()
elif self.device == 'xpu':
elif self.device == "xpu":
self.fast_dtype = torch.xpu.get_autocast_xpu_dtype() # type: ignore[attr-defined]
elif self.device == 'ipu':
elif self.device == "ipu":
self.fast_dtype = torch.get_autocast_ipu_dtype() # type: ignore[attr-defined]
elif self.device == 'hpu':
elif self.device == "hpu":
self.fast_dtype = torch.hpu.get_autocast_hpu_dtype() # type: ignore[attr-defined]
elif self.device == 'xla':
elif self.device == "xla":
self.fast_dtype = torch.get_autocast_xla_dtype() # type: ignore[attr-defined]
elif self.device == self.custom_backend_name:
necessary_funcs = ['is_autocast_enabled', 'set_autocast_enabled', 'get_autocast_dtype',
'set_autocast_dtype', 'get_amp_supported_dtype']
necessary_funcs = [
"is_autocast_enabled",
"set_autocast_enabled",
"get_autocast_dtype",
"set_autocast_dtype",
"get_amp_supported_dtype",
]
message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
message += "registered a module or the module miss some necessary funcs. The backend should register "
message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, "
message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) "
message += "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
message += (
"-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
)
assert hasattr(torch, self.custom_backend_name), message
self.custom_device_mod = getattr(torch, self.custom_backend_name)
for func in necessary_funcs:
assert hasattr(self.custom_device_mod, func), message + f"But the func `{func}` is missing. \n"
assert hasattr(self.custom_device_mod, func), (
message + f"But the func `{func}` is missing. \n"
)
self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
else:
raise RuntimeError(f'User specified an unsupported autocast device_type \'{self.device}\'')
raise RuntimeError(
f"User specified an unsupported autocast device_type '{self.device}'"
)
self._cache_enabled = torch.is_autocast_cache_enabled()
if enabled and torch.cuda.amp.common.amp_definitely_not_available() and self.device == 'cuda':
warnings.warn('User provided device_type of \'cuda\', but CUDA is not available. Disabling')
if (
enabled
and torch.cuda.amp.common.amp_definitely_not_available()
and self.device == "cuda"
):
warnings.warn(
"User provided device_type of 'cuda', but CUDA is not available. Disabling"
)
enabled = False
if dtype is not None:
self.fast_dtype = dtype
if cache_enabled is not None:
self._cache_enabled = cache_enabled
if self.device == 'cpu':
if self.device == "cpu":
supported_dtype = [torch.bfloat16]
if self.fast_dtype not in supported_dtype:
error_message = 'In CPU autocast, but the target dtype is not supported. Disabling autocast.\n'
error_message += 'CPU Autocast only supports dtype of torch.bfloat16 currently.'
error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += (
"CPU Autocast only supports dtype of torch.bfloat16 currently."
)
warnings.warn(error_message)
enabled = False
elif self.device == 'xpu':
elif self.device == "xpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = 'In XPU autocast, but the target dtype is not supported. Disabling autocast.\n'
error_message += 'XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == 'ipu':
elif self.device == "ipu":
supported_dtypes = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtypes:
error_message = 'In IPU autocast, but the target dtype is not supported. Disabling autocast.\n'
error_message += 'IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == 'hpu':
elif self.device == "hpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = 'In HPU autocast, but the target dtype is not supported. Disabling autocast.\n'
error_message += 'HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == self.custom_backend_name:
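Note the design choice visible throughout this hunk: an unsupported dtype warns and disables autocast rather than raising. A hedged sketch against the CPU branch shown above:

```python
# Hedged sketch of the CPU validation branch: an unsupported dtype emits a
# warning and silently disables autocast instead of raising an error.
import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.autocast(device_type="cpu", dtype=torch.float64):
        print(torch.is_autocast_cpu_enabled())  # False: disabled, not an error
print(caught[0].message)  # "In CPU autocast, but the target dtype is not supported. ..."
```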
@@ -264,17 +291,27 @@ class autocast:
if self.fast_dtype not in supported_dtype:
error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. "
error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
error_message += ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
error_message += (
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
)
warnings.warn(error_message)
enabled = False
elif self.device == 'cuda':
if enabled and self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
elif self.device == 'xla':
elif self.device == "cuda":
if (
enabled
and self.fast_dtype == torch.bfloat16
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError(
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
)
elif self.device == "xla":
supported_dtype = [torch.bfloat16]
if self.fast_dtype not in supported_dtype:
error_message = 'In XLA autocast, but the target dtype is not supported. Disabling autocast.\n'
error_message += 'XLA Autocast only supports dtype of torch.bfloat16 currently.'
error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += (
"XLA Autocast only supports dtype of torch.bfloat16 currently."
)
warnings.warn(error_message)
enabled = False
self._enabled = enabled
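The CUDA branch above is the one place that raises instead of warning. A hypothetical call-site guard (not part of this diff) that sidesteps that RuntimeError:

```python
# Hypothetical guard mirroring the CUDA branch above: fall back to float16
# when the device cannot run bfloat16 under autocast.
import torch

dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
```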
@@ -285,31 +322,31 @@ class autocast:
return self
self.prev_cache_enabled = torch.is_autocast_cache_enabled()
if self.device == 'cpu':
if self.device == "cpu":
self.prev = torch.is_autocast_cpu_enabled()
self.prev_fastdtype = torch.get_autocast_cpu_dtype()
torch.set_autocast_cpu_enabled(self._enabled)
torch.set_autocast_cpu_dtype(self.fast_dtype) # type: ignore[arg-type]
torch.autocast_increment_nesting()
elif self.device == 'xpu':
self.prev = torch.xpu.is_autocast_xpu_enabled() # type: ignore[attr-defined]
elif self.device == "xpu":
self.prev = torch.xpu.is_autocast_xpu_enabled() # type: ignore[attr-defined]
self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype() # type: ignore[attr-defined]
torch.xpu.set_autocast_xpu_enabled(self._enabled) # type: ignore[attr-defined]
torch.xpu.set_autocast_xpu_dtype(self.fast_dtype) # type: ignore[attr-defined]
torch.autocast_increment_nesting()
elif self.device == 'ipu':
self.prev = torch.is_autocast_ipu_enabled() # type: ignore[attr-defined]
elif self.device == "ipu":
self.prev = torch.is_autocast_ipu_enabled() # type: ignore[attr-defined]
self.prev_fastdtype = torch.get_autocast_ipu_dtype() # type: ignore[attr-defined]
torch.set_autocast_ipu_enabled(self._enabled) # type: ignore[attr-defined]
torch.set_autocast_ipu_dtype(self.fast_dtype) # type: ignore[attr-defined]
torch.autocast_increment_nesting()
elif self.device == 'hpu':
self.prev = torch.hpu.is_autocast_hpu_enabled() # type: ignore[attr-defined]
elif self.device == "hpu":
self.prev = torch.hpu.is_autocast_hpu_enabled() # type: ignore[attr-defined]
self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype() # type: ignore[attr-defined]
torch.hpu.set_autocast_hpu_enabled(self._enabled) # type: ignore[attr-defined]
torch.hpu.set_autocast_hpu_dtype(self.fast_dtype) # type: ignore[attr-defined]
torch.autocast_increment_nesting()
elif self.device == 'xla':
elif self.device == "xla":
self.prev = torch.is_autocast_xla_enabled() # type: ignore[attr-defined]
self.prev_fastdtype = torch.get_autocast_xla_dtype() # type: ignore[attr-defined]
torch.set_autocast_xla_enabled(self._enabled) # type: ignore[attr-defined]
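The `__enter__` hunk above implements a save/restore contract per backend: stash the global flag and dtype, install the new ones, bump the nesting counter. A minimal sketch of that contract on CPU:

```python
# Hedged sketch of the save/restore contract in __enter__ (and the matching
# __exit__): global autocast state is stashed on entry and reinstated on exit.
import torch

prev_enabled = torch.is_autocast_cpu_enabled()
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    assert torch.is_autocast_cpu_enabled()
    assert torch.get_autocast_cpu_dtype() == torch.bfloat16
assert torch.is_autocast_cpu_enabled() == prev_enabled  # restored
```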
@@ -334,31 +371,31 @@ class autocast:
return
# Drop the cache when we exit to a nesting level that's outside any instance of autocast.
if self.device == 'cpu':
if self.device == "cpu":
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_cpu_enabled(self.prev)
torch.set_autocast_cpu_dtype(self.prev_fastdtype)
elif self.device == 'xpu':
elif self.device == "xpu":
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.xpu.set_autocast_xpu_enabled(self.prev) # type: ignore[attr-defined]
torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == 'ipu':
torch.xpu.set_autocast_xpu_enabled(self.prev) # type: ignore[attr-defined]
torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == "ipu":
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_ipu_enabled(self.prev) # type: ignore[attr-defined]
torch.set_autocast_ipu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == 'hpu':
torch.set_autocast_ipu_enabled(self.prev) # type: ignore[attr-defined]
torch.set_autocast_ipu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == "hpu":
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.hpu.set_autocast_hpu_enabled(self.prev) # type: ignore[attr-defined]
torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == 'xla':
torch.hpu.set_autocast_hpu_enabled(self.prev) # type: ignore[attr-defined]
torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == "xla":
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_xla_enabled(self.prev) # type: ignore[attr-defined]
torch.set_autocast_xla_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
torch.set_autocast_xla_enabled(self.prev) # type: ignore[attr-defined]
torch.set_autocast_xla_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == self.custom_backend_name:
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
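The comment at the top of this hunk is the key to `__exit__`: the weight cache is dropped only when `autocast_decrement_nesting()` reaches zero, i.e. when the outermost region exits. A sketch of that nesting behavior:

```python
# Sketch of the nesting counter used in __exit__ above: inner regions
# restore the enclosing state, and the cache is cleared only at depth 0.
import torch

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16, enabled=False):
        assert not torch.is_autocast_cpu_enabled()  # inner region disabled
    assert torch.is_autocast_cpu_enabled()  # outer state restored
# cache cleared here, once autocast_decrement_nesting() returns 0
```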
@@ -377,13 +414,16 @@ class autocast:
return func
return autocast_decorator(self, func)
# These functions aren't meant for public usage.
# They are what we trace into a graph during pre_dispatch tracing
# when we encounter an autocast context manager.
def _enter_autocast(*vals):
# For pre-dispatch tracing, if a TorchFunction mode is active, we'll want to trace this into a graph.
if torch._C._is_torch_function_mode_enabled():
return torch.overrides.handle_torch_function(torch.amp._enter_autocast, [], *vals)
return torch.overrides.handle_torch_function(
torch.amp._enter_autocast, [], *vals
)
mode = torch.amp.autocast(*vals)
mode.__enter__()
return mode
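As the comment above says, these helpers are private tracing hooks, not public API. A hedged sketch of how the pair composes, assuming `_exit_autocast` (defined alongside this function but cut off by the hunk) closes the mode returned here, and no enclosing autocast region is active:

```python
# Hedged sketch of the private tracing helpers: _enter_autocast opens an
# autocast mode and returns it so a traced graph can later close it via
# _exit_autocast. Positional args match __init__:
# (device_type, dtype, enabled, cache_enabled).
import torch

mode = torch.amp._enter_autocast("cpu", torch.bfloat16, True, None)
assert torch.is_autocast_cpu_enabled()
torch.amp._exit_autocast(mode)
assert not torch.is_autocast_cpu_enabled()  # assumes no outer region
```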