mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[Easy][AMP] Refactor the AMP logic for getting dtype (#162796)
As the title stated. Pull Request resolved: https://github.com/pytorch/pytorch/pull/162796 Approved by: https://github.com/ezyang
This commit is contained in:
@ -232,11 +232,11 @@ class autocast:
|
||||
)
|
||||
if dtype is None:
|
||||
dtype = torch.get_autocast_dtype(device_type)
|
||||
self.fast_dtype = dtype
|
||||
if torch._jit_internal.is_scripting():
|
||||
self._enabled = enabled
|
||||
self.device = device_type
|
||||
self.fast_dtype = dtype
|
||||
assert dtype is not None
|
||||
assert self.fast_dtype is not None
|
||||
return
|
||||
self.device = device_type
|
||||
if not is_autocast_available(self.device):
|
||||
@ -244,7 +244,6 @@ class autocast:
|
||||
f"User specified an unsupported autocast device_type '{self.device}'"
|
||||
)
|
||||
self.custom_backend_name = torch._C._get_privateuse1_backend_name()
|
||||
self.fast_dtype = torch.get_autocast_dtype(self.device)
|
||||
if self.device == self.custom_backend_name:
|
||||
necessary_funcs = [
|
||||
"get_amp_supported_dtype",
|
||||
@ -271,8 +270,6 @@ class autocast:
|
||||
"User provided device_type of 'cuda', but CUDA is not available. Disabling"
|
||||
)
|
||||
enabled = False
|
||||
if dtype is not None:
|
||||
self.fast_dtype = dtype
|
||||
if cache_enabled is not None:
|
||||
self._cache_enabled = cache_enabled
|
||||
|
||||
|
Reference in New Issue
Block a user