[Easy][AMP] Refactor the AMP logic for getting dtype (#162796)

As the title states: resolve the autocast dtype once at the top of autocast.__init__ and assign self.fast_dtype in a single place, instead of reassigning it across several branches.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162796
Approved by: https://github.com/ezyang
Author: FFFrog
Date: 2025-09-12 19:38:45 +08:00
Committed by: PyTorch MergeBot
Parent: 9ba918082a
Commit: d8cbbc0f70
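
For orientation, a minimal, illustrative sketch of the consolidated dtype handling that the diff below introduces: the effective dtype is resolved once, up front, and self.fast_dtype is assigned exactly once. The class name _AutocastDtypeSketch is made up for illustration; this is not the real torch.amp.autocast implementation.

import torch

class _AutocastDtypeSketch:
    # Illustrative only: condenses the dtype handling shown in the diff below.
    def __init__(self, device_type: str, dtype=None):
        # Resolve the effective dtype exactly once; later branches can simply
        # read self.fast_dtype instead of reassigning it.
        if dtype is None:
            dtype = torch.get_autocast_dtype(device_type)
        self.fast_dtype = dtype

sketch = _AutocastDtypeSketch("cpu")
assert sketch.fast_dtype == torch.get_autocast_dtype("cpu")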


@@ -232,11 +232,11 @@ class autocast:
             )
         if dtype is None:
             dtype = torch.get_autocast_dtype(device_type)
+        self.fast_dtype = dtype
         if torch._jit_internal.is_scripting():
             self._enabled = enabled
             self.device = device_type
-            self.fast_dtype = dtype
-            assert dtype is not None
+            assert self.fast_dtype is not None
             return
         self.device = device_type
         if not is_autocast_available(self.device):
@@ -244,7 +244,6 @@ class autocast:
                 f"User specified an unsupported autocast device_type '{self.device}'"
             )
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
-        self.fast_dtype = torch.get_autocast_dtype(self.device)
         if self.device == self.custom_backend_name:
             necessary_funcs = [
                 "get_amp_supported_dtype",
@@ -271,8 +270,6 @@ class autocast:
                 "User provided device_type of 'cuda', but CUDA is not available. Disabling"
             )
             enabled = False
-        if dtype is not None:
-            self.fast_dtype = dtype
         if cache_enabled is not None:
             self._cache_enabled = cache_enabled
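
Not part of the change itself; a small usage check, assuming a recent PyTorch build where CPU autocast defaults to bfloat16 and mm is autocast-eligible on CPU, of the behavior this refactor preserves: with no dtype argument, fast_dtype falls back to torch.get_autocast_dtype, while an explicit dtype is used as-is.

import torch

# No dtype argument: autocast resolves it from the device's default
# (bfloat16 on CPU), so an autocast-eligible op runs in that dtype.
default_dtype = torch.get_autocast_dtype("cpu")
with torch.autocast(device_type="cpu"):
    out = torch.mm(torch.randn(8, 8), torch.randn(8, 8))
assert out.dtype == default_dtype

# Explicit dtype: used as-is, no fallback.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out = torch.mm(torch.randn(8, 8), torch.randn(8, 8))
assert out.dtype == torch.bfloat16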