[AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)
## Description:
This PR refactors the autocast context manager in `autocast_mode.py` to simplify and centralize the logic for checking supported dtypes for each device. The previous implementation repeated nearly identical checks for every device type. Now a single `device_supported_dtypes` list associates device types with their supported dtypes (custom backends can override it via `get_amp_supported_dtype()`), and the validation logic is unified into one code path.
In my view, this makes the code easier to maintain and to extend to new devices.
Please share any suggestions and comments with me.
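For illustration, here is a minimal standalone sketch of the unified pattern. It assumes the default dtype list from the diff below; `check_autocast_dtype` is a hypothetical helper name for this sketch, not the method used in `autocast_mode.py`:

```python
import warnings

import torch

# Default list shared by in-tree devices; custom backends override it
# via their get_amp_supported_dtype() hook (see the diff below).
device_supported_dtypes = [torch.bfloat16, torch.float16]

def check_autocast_dtype(device: str, fast_dtype: torch.dtype, enabled: bool) -> bool:
    """Return the possibly-downgraded `enabled` flag after validating the dtype."""
    if enabled and fast_dtype not in device_supported_dtypes:
        warnings.warn(
            f"In {device.upper()} autocast, but the target dtype is not supported. "
            "Disabling autocast.\n"
            f"{device.upper()} Autocast only supports dtypes of "
            + ", ".join(map(str, device_supported_dtypes))
            + " currently."
        )
        return False
    return enabled

# One check now serves every device, e.g.:
print(check_autocast_dtype("cpu", torch.float64, enabled=True))  # warns, False
print(check_autocast_dtype("xpu", torch.float16, enabled=True))  # True
```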
BTW, in the original `xla` branch, the `supported_dtype` list was `[torch.float16, torch.bfloat16]` (5d8a226e23/torch/amp/autocast_mode.py (L358-L363)), but the warning message mentioned only `torch.bfloat16`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163446
Approved by: https://github.com/FFFrog, https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: e0abcee3b5
Commit: 960b0d5f0d
In the MPS autocast test (`TestAutocastMPS`), the expected warning text is updated to match the new unified message format:

```diff
@@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
     def test_mps_autocast_error_message(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
+            "MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
         ):
             with torch.autocast(device_type="mps", dtype=torch.float32):
                 _ = torch.ones(10)
```
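As a usage example (assuming an Apple-silicon machine with the MPS backend available), requesting an unsupported dtype now downgrades autocast to a no-op and emits the warning the test asserts:

```python
import torch

# float32 is not a supported autocast dtype on MPS, so autocast is
# disabled with a warning instead of raising.
with torch.autocast(device_type="mps", dtype=torch.float32):
    _ = torch.ones(10)
# UserWarning: In MPS autocast, but the target dtype is not supported. Disabling autocast.
# MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.
```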
In `torch/amp/autocast_mode.py`, the dtype default is folded into a single expression:

```diff
@@ -230,9 +230,9 @@ class autocast:
             raise ValueError(
                 f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
             )
-        if dtype is None:
-            dtype = torch.get_autocast_dtype(device_type)
-        self.fast_dtype = dtype
+        self.fast_dtype = (
+            torch.get_autocast_dtype(device_type) if dtype is None else dtype
+        )
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
```
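The two forms are behaviorally equivalent; a quick check of the fallback (the printed CPU default is an assumption about current builds, worth verifying on yours):

```python
import torch

dtype = None          # caller did not pass an explicit dtype
device_type = "cpu"
# One-expression form used after the refactor: fall back to the device's
# default autocast dtype only when the caller passed None.
fast_dtype = torch.get_autocast_dtype(device_type) if dtype is None else dtype
print(fast_dtype)     # torch.bfloat16 on current CPU builds (assumption)
```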
A shared default list of supported dtypes is introduced ahead of the per-device handling:

```diff
@@ -243,6 +243,9 @@ class autocast:
             raise RuntimeError(
                 f"User specified an unsupported autocast device_type '{self.device}'"
             )
+
+        device_supported_dtypes = [torch.bfloat16, torch.float16]
+
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
         if self.device == self.custom_backend_name:
             necessary_funcs = [
```
Finally, the repeated per-device branches collapse into one unified check, with special cases kept for CUDA availability/bfloat16 and for MPS bfloat16 on macOS < 14:

```diff
@@ -259,110 +262,55 @@ class autocast:
                 assert hasattr(self.custom_device_mod, func), (
                     message + f"But the func `{func}` is missing. \n"
                 )
+            device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
 
-        self._cache_enabled = torch.is_autocast_cache_enabled()
-        if (
-            enabled
-            and self.device == "cuda"
-            and torch.cuda.amp.common.amp_definitely_not_available()
-        ):
-            warnings.warn(
-                "User provided device_type of 'cuda', but CUDA is not available. Disabling"
-            )
-            enabled = False
-        if cache_enabled is not None:
-            self._cache_enabled = cache_enabled
+        self._cache_enabled = (
+            torch.is_autocast_cache_enabled()
+            if cache_enabled is None
+            else cache_enabled
+        )
 
-        if self.device == "cpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype and enabled:
-                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "CPU Autocast only supports dtype of "
-                error_message += (
-                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
-                )
-                warnings.warn(error_message)
-                enabled = False
-        elif self.device == "mtia":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message)
-                enabled = False
-        elif self.device == "maia":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message)
-                enabled = False
-        elif self.device == "xpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message)
-                enabled = False
-        elif self.device == "ipu":
-            supported_dtypes = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtypes:
-                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message)
-                enabled = False
-        elif self.device == "hpu":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
-                warnings.warn(error_message)
-                enabled = False
-        elif self.device == self.custom_backend_name:
-            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
-            if self.fast_dtype not in supported_dtype:
-                error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
-                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
-                error_message += (
-                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
-                )
-                warnings.warn(error_message)
-                enabled = False
-        elif self.device == "cuda":
-            if (
-                enabled
-                and self.fast_dtype == torch.bfloat16
-                and not torch.cuda.is_bf16_supported()
-            ):
-                raise RuntimeError(
-                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
-                )
-        elif self.device == "mps":
-            supported_dtype = [torch.bfloat16, torch.float16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = (
-                    "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
-                    "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
-                )
-                warnings.warn(error_message)
-                enabled = False
-            elif self.fast_dtype == torch.bfloat16:
-                if not torch.backends.mps.is_macos_or_newer(14, 0):
-                    error_message = (
-                        "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
-                        "on macOS versions below 14. Disabling autocast."
-                    )
-                    warnings.warn(error_message)
-                    enabled = False
-        elif self.device == "xla":
-            supported_dtype = [torch.float16, torch.bfloat16]
-            if self.fast_dtype not in supported_dtype:
-                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
-                error_message += (
-                    "XLA Autocast only supports dtype of torch.bfloat16 currently."
-                )
-                warnings.warn(error_message)
-                enabled = False
+        device_name = (
+            self.device
+            if self.device == self.custom_backend_name
+            else self.device.upper()
+        )
+        if enabled:
+            # Special case for CUDA AMP and bfloat16 support
+            if self.device == "cuda":
+                if torch.cuda.amp.common.amp_definitely_not_available():
+                    warnings.warn(
+                        "CUDA is not available or torch_xla is imported. AMP disabled."
+                    )
+                    enabled = False
+                elif (
+                    self.fast_dtype == torch.bfloat16
+                    and not torch.cuda.is_bf16_supported()
+                ):
+                    raise RuntimeError(
+                        "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
+                    )
+            elif self.fast_dtype not in device_supported_dtypes:
+                error_message = (
+                    f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
+                    f"{device_name} Autocast only supports dtypes of "
+                    + ", ".join(map(str, device_supported_dtypes))
+                    + " currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
+
+            # Special case for MPS bfloat16 support on macOS < 14
+            if (
+                self.device == "mps"
+                and self.fast_dtype == torch.bfloat16
+                and not torch.backends.mps.is_macos_or_newer(14, 0)
+            ):
+                error_message = (
+                    "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
+                    "on macOS versions below 14. Disabling autocast."
+                )
+                warnings.warn(error_message)
+                enabled = False
         self._enabled = enabled
 
     def __enter__(self):
```
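To show how the custom-backend override feeds the same unified message, here is a hedged sketch; the backend module and its dtype list are hypothetical. A real out-of-tree backend would be registered via `torch.utils.rename_privateuse1_backend(...)` and `torch._register_device_module(...)`, after which autocast calls its `get_amp_supported_dtype()` to replace the default list:

```python
import torch

class FakeBackendModule:
    """Hypothetical device module for a PrivateUse1 backend."""

    @staticmethod
    def get_amp_supported_dtype():
        return [torch.float16]  # this backend supports only fp16 autocast

# With such a module registered, the custom-backend branch above swaps in
# the backend's list, and the warning text is built the same way as for
# every in-tree device:
device_supported_dtypes = FakeBackendModule.get_amp_supported_dtype()
print(", ".join(map(str, device_supported_dtypes)) + " currently.")
# -> torch.float16 currently.
```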