mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Revert "[AMP][Refactor] Simplify dtype support logic in autocast context manager (#163446)"
This reverts commit 960b0d5f0d0efb1f1962bddcf62e2a698e26edd2. Reverted https://github.com/pytorch/pytorch/pull/163446 on behalf of https://github.com/izaitsevfb due to breaks autocast tests on linux and mac ([comment](https://github.com/pytorch/pytorch/pull/163446#issuecomment-3390688642))
This commit is contained in:
@ -346,7 +346,7 @@ class TestAutocastMPS(TestCase):
|
||||
def test_mps_autocast_error_message(self):
|
||||
with self.assertWarnsRegex(
|
||||
UserWarning,
|
||||
"MPS Autocast only supports dtypes of torch.bfloat16, torch.float16 currently.",
|
||||
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently.",
|
||||
):
|
||||
with torch.autocast(device_type="mps", dtype=torch.float32):
|
||||
_ = torch.ones(10)
|
||||
|
@ -230,9 +230,9 @@ class autocast:
|
||||
raise ValueError(
|
||||
f"Expected `device_type` of type `str`, got: `{type(device_type)}`"
|
||||
)
|
||||
self.fast_dtype = (
|
||||
torch.get_autocast_dtype(device_type) if dtype is None else dtype
|
||||
)
|
||||
if dtype is None:
|
||||
dtype = torch.get_autocast_dtype(device_type)
|
||||
self.fast_dtype = dtype
|
||||
if torch._jit_internal.is_scripting():
|
||||
self._enabled = enabled
|
||||
self.device = device_type
|
||||
@ -243,9 +243,6 @@ class autocast:
|
||||
raise RuntimeError(
|
||||
f"User specified an unsupported autocast device_type '{self.device}'"
|
||||
)
|
||||
|
||||
device_supported_dtypes = [torch.bfloat16, torch.float16]
|
||||
|
||||
self.custom_backend_name = torch._C._get_privateuse1_backend_name()
|
||||
if self.device == self.custom_backend_name:
|
||||
necessary_funcs = [
|
||||
@ -262,55 +259,110 @@ class autocast:
|
||||
assert hasattr(self.custom_device_mod, func), (
|
||||
message + f"But the func `{func}` is missing. \n"
|
||||
)
|
||||
device_supported_dtypes = self.custom_device_mod.get_amp_supported_dtype()
|
||||
|
||||
self._cache_enabled = (
|
||||
torch.is_autocast_cache_enabled()
|
||||
if cache_enabled is None
|
||||
else cache_enabled
|
||||
)
|
||||
self._cache_enabled = torch.is_autocast_cache_enabled()
|
||||
if (
|
||||
enabled
|
||||
and self.device == "cuda"
|
||||
and torch.cuda.amp.common.amp_definitely_not_available()
|
||||
):
|
||||
warnings.warn(
|
||||
"User provided device_type of 'cuda', but CUDA is not available. Disabling"
|
||||
)
|
||||
enabled = False
|
||||
if cache_enabled is not None:
|
||||
self._cache_enabled = cache_enabled
|
||||
|
||||
device_name = (
|
||||
self.device
|
||||
if self.device == self.custom_backend_name
|
||||
else self.device.upper()
|
||||
)
|
||||
if enabled:
|
||||
# Special case for CUDA AMP and bfloat16 support
|
||||
if self.device == "cuda":
|
||||
if torch.cuda.amp.common.amp_definitely_not_available():
|
||||
warnings.warn(
|
||||
"CUDA is not available or torch_xla is imported. AMP disabled."
|
||||
)
|
||||
enabled = False
|
||||
elif (
|
||||
self.fast_dtype == torch.bfloat16
|
||||
and not torch.cuda.is_bf16_supported()
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
|
||||
)
|
||||
elif self.fast_dtype not in device_supported_dtypes:
|
||||
error_message = (
|
||||
f"In {device_name} autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
f"{device_name} Autocast only supports dtypes of "
|
||||
+ ", ".join(map(str, device_supported_dtypes))
|
||||
+ " currently."
|
||||
if self.device == "cpu":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype and enabled:
|
||||
error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "CPU Autocast only supports dtype of "
|
||||
error_message += (
|
||||
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
# Special case for MPS bfloat16 support on macOS < 14
|
||||
if (
|
||||
self.device == "mps"
|
||||
and self.fast_dtype == torch.bfloat16
|
||||
and not torch.backends.mps.is_macos_or_newer(14, 0)
|
||||
):
|
||||
elif self.device == "mtia":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In MTIA autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "MTIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "maia":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In MAIA autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "MAIA Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "xpu":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "ipu":
|
||||
supported_dtypes = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtypes:
|
||||
error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "hpu":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently."
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == self.custom_backend_name:
|
||||
supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = f"In {self.custom_backend_name} autocast, but the target dtype {self.fast_dtype} is not supported. "
|
||||
error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
|
||||
error_message += (
|
||||
", ".join(str(dtype) for dtype in supported_dtype) + " currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "cuda":
|
||||
if (
|
||||
enabled
|
||||
and self.fast_dtype == torch.bfloat16
|
||||
and not torch.cuda.is_bf16_supported()
|
||||
):
|
||||
raise RuntimeError(
|
||||
"Current CUDA Device does not support bfloat16. Please switch dtype to float16."
|
||||
)
|
||||
elif self.device == "mps":
|
||||
supported_dtype = [torch.bfloat16, torch.float16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = (
|
||||
"In MPS autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
"MPS Autocast only supports dtype of torch.bfloat16 and torch.float16 currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.fast_dtype == torch.bfloat16:
|
||||
if not torch.backends.mps.is_macos_or_newer(14, 0):
|
||||
error_message = (
|
||||
"In MPS autocast, but the target dtype torch.bfloat16 is not supported "
|
||||
"on macOS versions below 14. Disabling autocast."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
elif self.device == "xla":
|
||||
supported_dtype = [torch.float16, torch.bfloat16]
|
||||
if self.fast_dtype not in supported_dtype:
|
||||
error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n"
|
||||
error_message += (
|
||||
"XLA Autocast only supports dtype of torch.bfloat16 currently."
|
||||
)
|
||||
warnings.warn(error_message)
|
||||
enabled = False
|
||||
self._enabled = enabled
|
||||
|
||||
def __enter__(self):
|
||||
|
Reference in New Issue
Block a user