Revert "[AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)"

This reverts commit 960b0d5f0d0efb1f1962bddcf62e2a698e26edd2.

Reverted https://github.com/pytorch/pytorch/pull/163446 on behalf of https://github.com/izaitsevfb because it breaks autocast tests on Linux and macOS ([comment](https://github.com/pytorch/pytorch/pull/163446#issuecomment-3390688642))
PyTorch MergeBot
2025-10-10 15:12:46 +00:00
parent 55f01a48af
commit 9420944033
2 changed files with 97 additions and 45 deletions


@@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
def test_mps_autocast_error_message(self):
with self.assertWarnsRegex(
UserWarning,
"MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
):
with torch.autocast(device_type="mps", dtype=torch.float32):
_ = torch.ones(10)
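
For reference, the supported path that this message points users toward is ordinary mixed-precision execution on MPS. A minimal sketch, assuming an MPS-enabled macOS build (tensors and shapes here are illustrative only):

import torch
# float16 (and, on macOS 14+, bfloat16) are the dtypes MPS autocast accepts;
# matmuls inside the region run in the requested low-precision dtype.
x = torch.randn(8, 16, device="mps")
w = torch.randn(16, 4, device="mps")
with torch.autocast(device_type="mps", dtype=torch.float16):
    out = x @ w
print(out.dtype)  # torch.float16: the matmul ran under autocast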


@@ -230,9 +230,9 @@ class autocast:
raise ValueError(
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
)
self.fast_dtype = (
torch.get_autocast_dtype(device_type) if dtype is None else dtype
)
if dtype is None:
dtype = torch.get_autocast_dtype(device_type)
self.fast_dtype = dtype
if torch._jit_internal.is_scripting():
self._enabled = enabled
self.device = device_type
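
The restored lines above change only how the default is spelled, not what it resolves to: when the caller passes no dtype, torch.get_autocast_dtype supplies the per-device default. A minimal sketch of that fallback (the defaults shown are the usual ones and can differ across builds):

import torch
# With dtype=None, torch.autocast falls back to the device's default autocast
# dtype, typically torch.bfloat16 for "cpu" and torch.float16 for "cuda".
print(torch.get_autocast_dtype("cpu"))   # torch.bfloat16 on a default build
print(torch.get_autocast_dtype("cuda"))  # torch.float16 on a default build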
@@ -243,9 +243,6 @@ class autocast:
raise RuntimeError(
f"User specified an unsupported autocast device_type '{self.device}'"
)
device_supported_dtypes = [torch.bfloat16, torch.float16]
self.custom_backend_name = torch._C._get_privateuse1_backend_name()
if self.device == self.custom_backend_name:
necessary_funcs = [
@@ -262,55 +259,110 @@ class autocast:
assert hasattr(self.custom_device_mod, func), (
message + f"But the func `{func}` is missing. \n"
)
device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
self._cache_enabled = (
torch.is_autocast_cache_enabled()
if cache_enabled is None
else cache_enabled
)
self._cache_enabled = torch.is_autocast_cache_enabled()
if (
enabled
and self.device == "cuda"
and torch.cuda.amp.common.amp_definitely_not_available()
):
warnings.warn(
"User provided device_type of 'cuda', but CUDA is not available. Disabling"
)
enabled = False
if cache_enabled is not None:
self._cache_enabled = cache_enabled
device_name = (
self.device
if self.device == self.custom_backend_name
else self.device.upper()
)
if enabled:
# Special case for CUDA AMP and bfloat16 support
if self.device == "cuda":
if torch.cuda.amp.common.amp_definitely_not_available():
warnings.warn(
"CUDA is not available or torch_xla is imported. AMP disabled."
)
enabled = False
elif (
self.fast_dtype == torch.bfloat16
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError(
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
)
elif self.fast_dtype not in device_supported_dtypes:
error_message = (
f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
f"{device_name} Autocast only supports dtypes of "
+ ", ".join(map(str, device_supported_dtypes))
+ " currently."
if self.device == "cpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype and enabled:
error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "CPU Autocast only supports dtype of "
error_message += (
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
)
warnings.warn(error_message)
enabled = False
# Special case for MPS bfloat16 support on macOS < 14
if (
self.device == "mps"
and self.fast_dtype == torch.bfloat16
and not torch.backends.mps.is_macos_or_newer(14, 0)
):
elif self.device == "mtia":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "maia":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "xpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "ipu":
supported_dtypes = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtypes:
error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == "hpu":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
warnings.warn(error_message)
enabled = False
elif self.device == self.custom_backend_name:
supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
if self.fast_dtype not in supported_dtype:
error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
error_message += (
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
)
warnings.warn(error_message)
enabled = False
elif self.device == "cuda":
if (
enabled
and self.fast_dtype == torch.bfloat16
and not torch.cuda.is_bf16_supported()
):
raise RuntimeError(
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
)
elif self.device == "mps":
supported_dtype = [torch.bfloat16, torch.float16]
if self.fast_dtype not in supported_dtype:
error_message = (
"In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
)
warnings.warn(error_message)
enabled = False
elif self.fast_dtype == torch.bfloat16:
if not torch.backends.mps.is_macos_or_newer(14, 0):
error_message = (
"In MPS autocast, but the target dtype torch.bfloat16 is not supported "
"on macOS versions below 14. Disabling autocast."
)
warnings.warn(error_message)
enabled = False
elif self.device == "xla":
supported_dtype = [torch.float16, torch.bfloat16]
if self.fast_dtype not in supported_dtype:
error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
error_message += (
"XLA Autocast only supports dtype of torch.bfloat16 currently."
)
warnings.warn(error_message)
enabled = False
self._enabled = enabled
def __enter__(self):
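
Taken together, the restored per-device branches warn and fall back to autocast-disabled execution when the requested dtype is not in the device's supported list; only the CUDA bfloat16 case raises. A minimal sketch of the observable CPU behavior, as described by the restored branch (exact warning text may differ across versions):

import warnings
import torch
# torch.float64 is not in the CPU supported list [torch.bfloat16, torch.float16],
# so constructing the context manager emits a UserWarning and leaves autocast
# disabled instead of raising; the matmul below therefore stays in float32.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.autocast(device_type="cpu", dtype=torch.float64):
        out = torch.mm(torch.randn(2, 2), torch.randn(2, 2))
for warning in caught:
    print(warning.message)  # "... CPU Autocast only supports dtype of ..."
print(out.dtype)            # torch.float32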