refactor autocast python APIs (#124479)

# Motivation
Refactor the autocast usage in `torch/amp/autocast_mode.py` and `torch/utils/checkpoint.py` to fix a bug: a naming-convention conflict between the `torch.xxx.get_autocast_xxx_dtype` APIs used in `autocast_mode.py` and the `torch.xxx.get_autocast_dtype` APIs expected by `checkpoint.py`.

# Solution
Use the device-agnostic APIs, e.g. `torch.get_autocast_dtype`, instead.
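
For illustration, a minimal sketch of the device-agnostic calls and the per-backend names they replace (assuming a PyTorch build that includes this stack; `"cpu"` is used so it runs anywhere):

```python
import torch

device_type = "cpu"  # "cuda", "xpu", "hpu", "xla", "ipu", ... work the same way

torch.is_autocast_enabled(device_type)                 # was torch.is_autocast_cpu_enabled()
torch.get_autocast_dtype(device_type)                  # was torch.get_autocast_cpu_dtype()
torch.set_autocast_enabled(device_type, False)         # was torch.set_autocast_cpu_enabled(False)
torch.set_autocast_dtype(device_type, torch.bfloat16)  # was torch.set_autocast_cpu_dtype(torch.bfloat16)
```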

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124479
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #124359
Author: Yu, Guangye
Date: 2024-04-23 09:34:46 +00:00
Committed by: PyTorch MergeBot
Parent: f01275934b
Commit: cdc66e9dc3

7 changed files with 53 additions and 128 deletions

View File

@@ -174,12 +174,20 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
   }
 }
 
-inline at::ScalarType get_lower_precision_fp_from_device_type(
-    c10::DeviceType device_type) {
+inline bool is_autocast_available(c10::DeviceType device_type) {
   if (device_type == at::kCPU || device_type == at::kCUDA ||
       device_type == at::kXPU || device_type == at::kIPU ||
       device_type == at::kHPU || device_type == at::kXLA ||
       device_type == at::kPrivateUse1) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+inline at::ScalarType get_lower_precision_fp_from_device_type(
+    c10::DeviceType device_type) {
+  if (is_autocast_available(device_type)) {
     return get_autocast_dtype(device_type);
   } else {
     throw std::runtime_error(
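
The helper above is exposed to Python later in this commit as `torch._C._is_autocast_available`; a quick sketch of querying it (assuming a build with this patch applied):

```python
import torch

# Device types listed in the C++ helper report True regardless of whether the
# corresponding hardware is present; anything else (e.g. "meta") reports False.
for device_type in ("cpu", "cuda", "xpu", "ipu", "hpu", "xla", "meta"):
    print(device_type, torch._C._is_autocast_available(device_type))
```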

View File

@@ -336,7 +336,7 @@ class TestTorchAutocast(TestCase):
     def test_invalid_device(self):
         dev = "not a real device"
-        msg = f"unsupported autocast device_type '{dev}'"
+        msg = f"Invalid device string: '{dev}'"
         with self.assertRaisesRegex(RuntimeError, msg):
             with torch.autocast(device_type=dev):
                 _ = torch.tensor(1)
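
The expected message changes because the device string is now parsed into an `at::Device` by the new `_is_autocast_available` binding before any autocast-specific validation; a hedged illustration of the new behavior:

```python
import torch

try:
    with torch.autocast(device_type="not a real device"):
        pass
except RuntimeError as e:
    # The generic device-parsing error surfaces instead of the old
    # autocast-specific "unsupported autocast device_type" message.
    print(e)
```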

View File

@@ -1301,6 +1301,7 @@ def clear_autocast_cache() -> None: ...
 def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
 def is_autocast_cpu_enabled() -> _bool: ...
 def _is_any_autocast_enabled() -> _bool: ...
+def _is_autocast_available(device_type: str) -> _bool: ...
 def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
 def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
 def get_autocast_cpu_dtype() -> _dtype: ...

View File

@@ -199,35 +199,20 @@ class autocast:
             assert dtype is not None
             return
         self.device = device_type
+        if not torch._C._is_autocast_available(self.device):
+            raise RuntimeError(
+                f"User specified an unsupported autocast device_type '{self.device}'"
+            )
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
-        if self.device == "cuda":
-            self.fast_dtype = torch.get_autocast_gpu_dtype()
-        elif self.device == "cpu":
-            self.fast_dtype = torch.get_autocast_cpu_dtype()
-        elif self.device == "xpu":
-            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "ipu":
-            self.fast_dtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "hpu":
-            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "xla":
-            self.fast_dtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
-        elif self.device == self.custom_backend_name:
+        self.fast_dtype = torch.get_autocast_dtype(self.device)
+        if self.device == self.custom_backend_name:
             necessary_funcs = [
-                "is_autocast_enabled",
-                "set_autocast_enabled",
-                "get_autocast_dtype",
-                "set_autocast_dtype",
                 "get_amp_supported_dtype",
             ]
             message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
             message += "registered a module or the module miss some necessary funcs. The backend should register "
             message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
-            message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, "
-            message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) "
-            message += (
-                "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
-            )
+            message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"
 
             assert hasattr(torch, self.custom_backend_name), message
             self.custom_device_mod = getattr(torch, self.custom_backend_name)
@@ -236,11 +221,6 @@ class autocast:
                     message + f"But the func `{func}` is missing. \n"
                 )
 
-            self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
-        else:
-            raise RuntimeError(
-                f"User specified an unsupported autocast device_type '{self.device}'"
-            )
         self._cache_enabled = torch.is_autocast_cache_enabled()
         if (
             enabled
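
For reference, the dtype that `__init__` now reads through the device-agnostic getter; a small sketch (the printed defaults are the stock values and can be overridden per device type):

```python
import torch

# Default autocast dtypes reported by the device-agnostic getter used above.
print(torch.get_autocast_dtype("cuda"))  # torch.float16 by default
print(torch.get_autocast_dtype("cpu"))   # torch.bfloat16 by default
```
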
@@ -323,48 +303,11 @@ class autocast:
             return self
 
         self.prev_cache_enabled = torch.is_autocast_cache_enabled()
-        if self.device == "cpu":
-            self.prev = torch.is_autocast_cpu_enabled()
-            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
-            torch.set_autocast_cpu_enabled(self._enabled)
-            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
-            torch.autocast_increment_nesting()
-        elif self.device == "xpu":
-            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "ipu":
-            self.prev = torch.is_autocast_ipu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "hpu":
-            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "xla":
-            self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
-            torch.set_autocast_xla_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == self.custom_backend_name:
-            self.prev = self.custom_device_mod.is_autocast_enabled()
-            self.prev_fastdtype = self.custom_device_mod.get_autocast_dtype()
-            self.custom_device_mod.set_autocast_enabled(self._enabled)
-            self.custom_device_mod.set_autocast_dtype(self.fast_dtype)
-            torch.autocast_increment_nesting()
-        else:
-            self.prev = torch.is_autocast_enabled()
-            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
-            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
-            torch.set_autocast_enabled(self._enabled)
-            torch.autocast_increment_nesting()
+        self.prev = torch.is_autocast_enabled(self.device)
+        self.prev_fastdtype = torch.get_autocast_dtype(self.device)
+        torch.set_autocast_enabled(self.device, self._enabled)
+        torch.set_autocast_dtype(self.device, self.fast_dtype)  # type: ignore[arg-type]
+        torch.autocast_increment_nesting()
         torch.set_autocast_cache_enabled(self._cache_enabled)
 
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
@@ -372,41 +315,10 @@ class autocast:
             return
 
         # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
-        if self.device == "cpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_cpu_enabled(self.prev)
-            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
-        elif self.device == "xpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "ipu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_ipu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "hpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "xla":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_xla_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == self.custom_backend_name:
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            self.custom_device_mod.set_autocast_enabled(self.prev)
-            self.custom_device_mod.set_autocast_dtype(self.prev_fastdtype)
-        else:
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_enabled(self.prev)
-            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
+        if torch.autocast_decrement_nesting() == 0:
+            torch.clear_autocast_cache()
+        torch.set_autocast_enabled(self.device, self.prev)
+        torch.set_autocast_dtype(self.device, self.prev_fastdtype)
         torch.set_autocast_cache_enabled(self.prev_cache_enabled)
         return False
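
Spelled out, the `__enter__`/`__exit__` pair above reduces to one save/set/restore round trip built from the device-agnostic APIs; a minimal sketch (using `"cpu"` so it runs on any build):

```python
import torch

device_type = "cpu"          # any autocast-capable device type works the same way
fast_dtype = torch.bfloat16  # the dtype the context manager would install

# __enter__: remember the previous state, then install the requested one.
prev_enabled = torch.is_autocast_enabled(device_type)
prev_dtype = torch.get_autocast_dtype(device_type)
torch.set_autocast_enabled(device_type, True)
torch.set_autocast_dtype(device_type, fast_dtype)
torch.autocast_increment_nesting()

# ... autocast-eligible ops would run here ...

# __exit__: drop the cache when leaving the outermost autocast, then restore.
if torch.autocast_decrement_nesting() == 0:
    torch.clear_autocast_cache()
torch.set_autocast_enabled(device_type, prev_enabled)
torch.set_autocast_dtype(device_type, prev_dtype)
```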

View File

@@ -574,6 +574,24 @@ static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* is_autocast_available(
+    PyObject* _unused,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser(
+      {"_is_autocast_available(c10::string_view device_type)"});
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  auto device_type = at::Device(r.string(0)).type();
+  if (at::autocast::is_autocast_available(device_type)) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
   TORCH_CHECK_TYPE(
@@ -1235,6 +1253,10 @@ static PyMethodDef methods[] = { // NOLINT
      METH_VARARGS | METH_KEYWORDS,
      nullptr},
     {"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
+    {"_is_autocast_available",
+     castPyCFunctionWithKeywords(is_autocast_available),
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
     {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
     {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
     {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},

View File

@@ -36,20 +36,6 @@ def rename_privateuse1_backend(backend_name: str) -> None:
     (1) ``get_amp_supported_dtype() -> List[torch.dtype]``
         get the supported dtypes on your "foo" device in AMP, maybe the "foo" device supports one more dtype.
 
-    (2) ``is_autocast_enabled() -> bool``
-        check the AMP is enabled or not on your "foo" device.
-
-    (3) ``get_autocast_dtype() -> torch.dtype``
-        get the supported dtype on your "foo" device in AMP, which is set by ``set_autocast_dtype`` or the
-        default dtype, and the default dtype is ``torch.float16``.
-
-    (4) ``set_autocast_enabled(bool) -> None``
-        enable the AMP or not on your "foo" device.
-
-    (5) ``set_autocast_dtype(dtype) -> None``
-        set the supported dtype on your "foo" device in AMP, and the dtype be contained in the dtypes got
-        from ``get_amp_supported_dtype``.
-
     Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's:
 
     (1) ``_is_in_bad_fork() -> bool``
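
For an out-of-tree backend this shrinks the AMP surface to a single hook; a minimal sketch for a hypothetical PrivateUse1 backend named "foo" (the name, module, and dtype list are illustrative only):

```python
import types
import torch

# Hypothetical backend name; real integrations pick their own.
torch.utils.rename_privateuse1_backend("foo")

foo = types.ModuleType("torch.foo")

def get_amp_supported_dtype():
    # The only AMP hook the backend module still has to provide; the autocast
    # enabled/dtype state is now handled by the device-agnostic torch APIs.
    return [torch.float16, torch.bfloat16]

foo.get_amp_supported_dtype = get_amp_supported_dtype
torch._register_device_module("foo", foo)
```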

View File

@@ -194,7 +194,7 @@ def set_device_states(devices, states) -> None:
 
 def _get_autocast_kwargs(device="cuda"):
-    if _supports_autocast(device):
+    if torch._C._is_autocast_available(device):
         device_autocast_kwargs = {
             "enabled": torch.is_autocast_enabled(device),
             "dtype": torch.get_autocast_dtype(device),
@@ -211,10 +211,6 @@ def _get_autocast_kwargs(device="cuda"):
     return device_autocast_kwargs, cpu_autocast_kwargs
 
-def _supports_autocast(device):
-    device_module = _get_device_module(device)
-    return device == "cuda" or (hasattr(device_module, "is_autocast_enabled")
-                                and hasattr(device_module, "get_autocast_dtype"))
 
 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
@@ -293,7 +289,7 @@ class CheckpointFunction(torch.autograd.Function):
             device_autocast_ctx = device_module.amp.autocast(
                 **ctx.device_autocast_kwargs
-            ) if _supports_autocast(ctx.device) else contextlib.nullcontext()
+            ) if torch._C._is_autocast_available(ctx.device) else contextlib.nullcontext()
             with torch.enable_grad(), device_autocast_ctx, \
                  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                 outputs = ctx.run_function(*detached_inputs)
@@ -1400,7 +1396,7 @@ def _checkpoint_without_reentrant_generator(
         device_autocast_ctx = device_module.amp.autocast(
             **device_autocast_kwargs
-        ) if _supports_autocast(device) else contextlib.nullcontext()
+        ) if torch._C._is_autocast_available(device) else contextlib.nullcontext()
         with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
              recompute_context:
             fn(*args, **kwargs)
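
Taken together, checkpointing now captures and replays autocast state for any device type behind the single availability check; a small sketch of that pattern (the helper names below are illustrative, not the ones in `torch/utils/checkpoint.py`):

```python
import contextlib
import torch

def capture_autocast_state(device_type: str):
    # Mirror of the capture step: record enabled/dtype/cache state only when
    # autocast is available for this device type.
    if not torch._C._is_autocast_available(device_type):
        return None
    return {
        "enabled": torch.is_autocast_enabled(device_type),
        "dtype": torch.get_autocast_dtype(device_type),
        "cache_enabled": torch.is_autocast_cache_enabled(),
    }

def autocast_replay_ctx(device_type: str, state):
    # Mirror of the replay step: re-enter autocast with the captured settings
    # during recomputation, or do nothing when autocast is unavailable.
    if state is None:
        return contextlib.nullcontext()
    return torch.autocast(device_type=device_type, **state)

state = capture_autocast_state("cpu")
with autocast_replay_ctx("cpu", state):
    pass  # recomputation would run here
```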