Revert "[AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)"

This reverts commit 960b0d5f0d0efb1f1962bddcf62e2a698e26edd2.

Reverted https://github.com/pytorch/pytorch/pull/163446 on behalf of https://github.com/izaitsevfb due to breaks autocast tests on linux and mac ([comment](https://github.com/pytorch/pytorch/pull/163446#issuecomment-3390688642))
PyTorch MergeBot
2025-10-10 15:12:46 +00:00
parent 55f01a48af
commit 9420944033
2 changed files with 97 additions and 45 deletions


@@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
     def test_mps_autocast_error_message(self):
         with self.assertWarnsRegex(
             UserWarning,
-            "MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
+            "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
         ):
             with torch.autocast(device_type="mps", dtype=torch.float32):
                 _ = torch.ones(10)
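For reference, the assertion above matches the warning emitted by the restored MPS branch of `autocast.__init__` (second diff below). A minimal sketch of that behavior, assuming an MPS-capable macOS build of PyTorch:

```python
import warnings
import torch

# Sketch only (assumes an MPS-capable macOS build): an unsupported dtype does
# not raise; autocast emits a UserWarning and stays disabled.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.autocast(device_type="mps", dtype=torch.float32):
        _ = torch.ones(10)

print([str(w.message) for w in caught])
# Expected to include:
# "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
```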


@@ -230,9 +230,9 @@ class autocast:
             raise ValueError(
                 f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
             )
-        self.fast_dtype = (
-            torch.get_autocast_dtype(device_type) if dtype is None else dtype
-        )
+        if dtype is None:
+            dtype = torch.get_autocast_dtype(device_type)
+        self.fast_dtype = dtype
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
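Both the reverted one-liner and the restored form resolve a missing `dtype` the same way: when `dtype` is `None`, the per-device default from `torch.get_autocast_dtype` is used. A small sketch of that equivalence:

```python
import torch

# When `dtype` is omitted, autocast falls back to the per-device default
# reported by torch.get_autocast_dtype (bfloat16 for CPU, float16 for CUDA
# in stock builds).
print(torch.get_autocast_dtype("cpu"))   # torch.bfloat16
print(torch.get_autocast_dtype("cuda"))  # torch.float16

# Hence these two context managers select the same fast_dtype:
with torch.autocast(device_type="cpu"):
    pass
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    pass
```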
@@ -243,9 +243,6 @@ class autocast:
             raise RuntimeError(
                 f"User specified an unsupported autocast device_type '{self.device}'"
             )
-        device_supported_dtypes = [torch.bfloat16, torch.float16]
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
         if self.device == self.custom_backend_name:
             necessary_funcs = [
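The custom-backend (PrivateUse1) path that continues in the next hunk requires the registered device module to expose `get_amp_supported_dtype()`, which autocast consults for its dtype check. A hedged sketch of the registration side, using the hypothetical backend name `my_backend`:

```python
import types
import torch

# Hedged sketch; "my_backend" is a hypothetical name. The PrivateUse1 path in
# autocast expects the registered device module to provide
# get_amp_supported_dtype(), returning the dtypes its AMP implementation accepts.
torch.utils.rename_privateuse1_backend("my_backend")

my_backend = types.ModuleType("my_backend")
my_backend.get_amp_supported_dtype = lambda: [torch.float16, torch.bfloat16]
torch._register_device_module("my_backend", my_backend)

# autocast can now validate dtype requests for this backend, e.g.
# with torch.autocast(device_type="my_backend", dtype=torch.float16): ...
```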
@@ -262,55 +259,110 @@ class autocast:
                 assert hasattr(self.custom_device_mod, func), (
                     message + f"But the func `{func}` is missing. \n"
                 )
-            device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
-        self._cache_enabled = (
-            torch.is_autocast_cache_enabled()
-            if cache_enabled is None
-            else cache_enabled
-        )
-        device_name = (
-            self.device
-            if self.device == self.custom_backend_name
-            else self.device.upper()
-        )
-        if enabled:
-            # Special case for CUDA AMP and bfloat16 support
-            if self.device == "cuda":
-                if torch.cuda.amp.common.amp_definitely_not_available():
-                    warnings.warn(
-                        "CUDA is not available or torch_xla is imported. AMP disabled."
-                    )
-                    enabled = False
-                elif (
-                    self.fast_dtype == torch.bfloat16
-                    and not torch.cuda.is_bf16_supported()
-                ):
-                    raise RuntimeError(
-                        "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
-                    )
-            elif self.fast_dtype not in device_supported_dtypes:
-                error_message = (
-                    f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
-                    f"{device_name} Autocast only supports dtypes of "
-                    + ", ".join(map(str, device_supported_dtypes))
-                    + " currently."
-                )
-                warnings.warn(error_message)
-                enabled = False
-            # Special case for MPS bfloat16 support on macOS < 14
-            if (
-                self.device == "mps"
-                and self.fast_dtype == torch.bfloat16
-                and not torch.backends.mps.is_macos_or_newer(14, 0)
-            ):
-                error_message = (
-                    "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
-                    "on macOS versions below 14. Disabling autocast."
-                )
-                warnings.warn(error_message)
-                enabled = False
+        self._cache_enabled = torch.is_autocast_cache_enabled()
+        if (
+            enabled
+            and self.device == "cuda"
+            and torch.cuda.amp.common.amp_definitely_not_available()
+        ):
+            warnings.warn(
+                "User provided device_type of 'cuda', but CUDA is not available. Disabling"
+            )
+            enabled = False
+        if cache_enabled is not None:
+            self._cache_enabled = cache_enabled
+        if self.device == "cpu":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype and enabled:
+                error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "CPU Autocast only supports dtype of "
+                error_message += (
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "mtia":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "maia":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "xpu":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "ipu":
+            supported_dtypes = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtypes:
+                error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "hpu":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == self.custom_backend_name:
+            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
+            if self.fast_dtype not in supported_dtype:
+                error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
+                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
+                error_message += (
+                    ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
+        elif self.device == "cuda":
+            if (
+                enabled
+                and self.fast_dtype == torch.bfloat16
+                and not torch.cuda.is_bf16_supported()
+            ):
+                raise RuntimeError(
+                    "Current CUDA Device does not support bfloat16. Please switch dtype to float16."
+                )
+        elif self.device == "mps":
+            supported_dtype = [torch.bfloat16, torch.float16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = (
+                    "In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
+                    "MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
+            elif self.fast_dtype == torch.bfloat16:
+                if not torch.backends.mps.is_macos_or_newer(14, 0):
+                    error_message = (
+                        "In MPS autocast, but the target dtype torch.bfloat16 is not supported "
+                        "on macOS versions below 14. Disabling autocast."
+                    )
+                    warnings.warn(error_message)
+                    enabled = False
+        elif self.device == "xla":
+            supported_dtype = [torch.float16, torch.bfloat16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
+                error_message += (
+                    "XLA Autocast only supports dtype of torch.bfloat16 currently."
+                )
+                warnings.warn(error_message)
+                enabled = False
         self._enabled = enabled

     def __enter__(self):
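The restored per-device chain warns and disables autocast when the requested dtype is unsupported; the only hard error is requesting bfloat16 on a CUDA device that lacks it. A minimal CPU-only sketch of that behavior:

```python
import warnings
import torch

# Minimal sketch: on CPU, an unsupported dtype produces a UserWarning and
# autocast is switched off instead of raising.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    with torch.autocast(device_type="cpu", dtype=torch.float32):
        inside = torch.is_autocast_enabled("cpu")

print([str(w.message) for w in caught])  # mentions "CPU Autocast only supports dtype of ..."
print(inside)                            # False: autocast was disabled for the block
```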