refactor autocast python APIs (#124479)

# Motivation
Refactor the autocast usage in `torch/amp/autocast_mode.py` and `torch/utils/checkpoint.py` to fix a bug: a naming-convention conflict between the device-specific `torch.xxx.get_autocast_xxx_dtype` APIs used in `autocast_mode.py` and the device-agnostic `torch.xxx.get_autocast_dtype` used in `checkpoint.py`.

# Solution
Use the device-agnostic APIs instead: `torch.get_autocast_dtype`, `torch.set_autocast_dtype`, `torch.is_autocast_enabled`, and `torch.set_autocast_enabled`, each taking the device type as an argument.
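
For illustration, a minimal sketch of the convention change, assuming a build that includes this PR (the per-device variants in the comments are the ones removed in the diffs below):

```python
import torch

# Old convention: one API family per backend, e.g.
#   torch.get_autocast_cpu_dtype()
#   torch.xpu.get_autocast_xpu_dtype()
# New convention: one device-agnostic API taking the device type as an argument.
print(torch.get_autocast_dtype("cpu"))   # torch.bfloat16 (the CPU default)
print(torch.get_autocast_dtype("cuda"))  # torch.float16 (the CUDA default)
torch.set_autocast_dtype("cpu", torch.bfloat16)
```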

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124479
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #124359
Authored by: Yu, Guangye
Date: 2024-04-23 09:34:46 +00:00
Committed by: PyTorch MergeBot
Parent: f01275934b
Commit: cdc66e9dc3

7 changed files with 53 additions and 128 deletions


@@ -174,12 +174,20 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
   }
 }

-inline at::ScalarType get_lower_precision_fp_from_device_type(
-    c10::DeviceType device_type) {
+inline bool is_autocast_available(c10::DeviceType device_type) {
+  if (device_type == at::kCPU || device_type == at::kCUDA ||
+      device_type == at::kXPU || device_type == at::kIPU ||
+      device_type == at::kHPU || device_type == at::kXLA ||
+      device_type == at::kPrivateUse1) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+inline at::ScalarType get_lower_precision_fp_from_device_type(
+    c10::DeviceType device_type) {
+  if (is_autocast_available(device_type)) {
+    return get_autocast_dtype(device_type);
+  } else {
     throw std::runtime_error(


@@ -336,7 +336,7 @@ class TestTorchAutocast(TestCase):
     def test_invalid_device(self):
         dev = "not a real device"
-        msg = f"unsupported autocast device_type '{dev}'"
+        msg = f"Invalid device string: '{dev}'"
         with self.assertRaisesRegex(RuntimeError, msg):
             with torch.autocast(device_type=dev):
                 _ = torch.tensor(1)
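
With the device check now routed through `torch._C._is_autocast_available`, an unparseable device string surfaces as a device-parsing error, which is what the updated test asserts. A sketch of the observable behavior, assuming a build with this change:

```python
import torch

try:
    with torch.autocast(device_type="not a real device"):
        pass
except RuntimeError as e:
    print(e)  # Invalid device string: 'not a real device'
```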


@@ -1301,6 +1301,7 @@ def clear_autocast_cache() -> None: ...
 def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
 def is_autocast_cpu_enabled() -> _bool: ...
 def _is_any_autocast_enabled() -> _bool: ...
+def _is_autocast_available(device_type: str) -> _bool: ...
 def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
 def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
 def get_autocast_cpu_dtype() -> _dtype: ...
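
The new private binding can be probed directly. A sketch, assuming a build with this change (the allowlist is the one in `is_autocast_available` above):

```python
import torch

print(torch._C._is_autocast_available("cpu"))   # True
print(torch._C._is_autocast_available("cuda"))  # True
print(torch._C._is_autocast_available("mps"))   # False: not in the allowlist at this point
```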


@@ -199,35 +199,20 @@ class autocast:
             assert dtype is not None
             return
         self.device = device_type
+        if not torch._C._is_autocast_available(self.device):
+            raise RuntimeError(
+                f"User specified an unsupported autocast device_type '{self.device}'"
+            )
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
-        if self.device == "cuda":
-            self.fast_dtype = torch.get_autocast_gpu_dtype()
-        elif self.device == "cpu":
-            self.fast_dtype = torch.get_autocast_cpu_dtype()
-        elif self.device == "xpu":
-            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "ipu":
-            self.fast_dtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "hpu":
-            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "xla":
-            self.fast_dtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
-        elif self.device == self.custom_backend_name:
+        self.fast_dtype = torch.get_autocast_dtype(self.device)
+        if self.device == self.custom_backend_name:
             necessary_funcs = [
-                "is_autocast_enabled",
-                "set_autocast_enabled",
-                "get_autocast_dtype",
-                "set_autocast_dtype",
                 "get_amp_supported_dtype",
             ]
             message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
             message += "registered a module or the module miss some necessary funcs. The backend should register "
             message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
-            message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, "
-            message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) "
-            message += (
-                "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
-            )
+            message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"

             assert hasattr(torch, self.custom_backend_name), message
             self.custom_device_mod = getattr(torch, self.custom_backend_name)
@@ -236,11 +221,6 @@
                     message + f"But the func `{func}` is missing. \n"
                 )

-            self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
-        else:
-            raise RuntimeError(
-                f"User specified an unsupported autocast device_type '{self.device}'"
-            )
         self._cache_enabled = torch.is_autocast_cache_enabled()
         if (
             enabled
@@ -323,48 +303,11 @@ class autocast:
             return self

         self.prev_cache_enabled = torch.is_autocast_cache_enabled()
-        if self.device == "cpu":
-            self.prev = torch.is_autocast_cpu_enabled()
-            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
-            torch.set_autocast_cpu_enabled(self._enabled)
-            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
-            torch.autocast_increment_nesting()
-        elif self.device == "xpu":
-            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "ipu":
-            self.prev = torch.is_autocast_ipu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "hpu":
-            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "xla":
-            self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
-            torch.set_autocast_xla_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == self.custom_backend_name:
-            self.prev = self.custom_device_mod.is_autocast_enabled()
-            self.prev_fastdtype = self.custom_device_mod.get_autocast_dtype()
-            self.custom_device_mod.set_autocast_enabled(self._enabled)
-            self.custom_device_mod.set_autocast_dtype(self.fast_dtype)
-            torch.autocast_increment_nesting()
-        else:
-            self.prev = torch.is_autocast_enabled()
-            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
-            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
-            torch.set_autocast_enabled(self._enabled)
-            torch.autocast_increment_nesting()
+        self.prev = torch.is_autocast_enabled(self.device)
+        self.prev_fastdtype = torch.get_autocast_dtype(self.device)
+        torch.set_autocast_enabled(self.device, self._enabled)
+        torch.set_autocast_dtype(self.device, self.fast_dtype)  # type: ignore[arg-type]
+        torch.autocast_increment_nesting()
         torch.set_autocast_cache_enabled(self._cache_enabled)

     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
@@ -372,41 +315,10 @@ class autocast:
             return

         # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
-        if self.device == "cpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_cpu_enabled(self.prev)
-            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
-        elif self.device == "xpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "ipu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_ipu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "hpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "xla":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_xla_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == self.custom_backend_name:
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            self.custom_device_mod.set_autocast_enabled(self.prev)
-            self.custom_device_mod.set_autocast_dtype(self.prev_fastdtype)
-        else:
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_enabled(self.prev)
-            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
+        if torch.autocast_decrement_nesting() == 0:
+            torch.clear_autocast_cache()
+        torch.set_autocast_enabled(self.device, self.prev)
+        torch.set_autocast_dtype(self.device, self.prev_fastdtype)
         torch.set_autocast_cache_enabled(self.prev_cache_enabled)
         return False
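
Since every device now goes through the same save/set/restore sequence, the context-manager behavior can be checked uniformly; a minimal sketch on CPU, assuming a build with this change:

```python
import torch

assert not torch.is_autocast_enabled("cpu")
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # __enter__ saved the previous state and applied ours.
    assert torch.is_autocast_enabled("cpu")
    assert torch.get_autocast_dtype("cpu") == torch.bfloat16
# __exit__ restored the previous (disabled) state.
assert not torch.is_autocast_enabled("cpu")
```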


@@ -574,6 +574,24 @@ static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }

+static PyObject* is_autocast_available(
+    PyObject* _unused,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser(
+      {"_is_autocast_available(c10::string_view device_type)"});
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  auto device_type = at::Device(r.string(0)).type();
+  if (at::autocast::is_autocast_available(device_type)) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
   TORCH_CHECK_TYPE(
@@ -1235,6 +1253,10 @@ static PyMethodDef methods[] = { // NOLINT
      METH_VARARGS | METH_KEYWORDS,
      nullptr},
     {"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
+    {"_is_autocast_available",
+     castPyCFunctionWithKeywords(is_autocast_available),
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
     {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
     {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
     {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},


@@ -36,20 +36,6 @@ def rename_privateuse1_backend(backend_name: str) -> None:
     (1) ``get_amp_supported_dtype() -> List[torch.dtype]``
         get the supported dtypes on your "foo" device in AMP, maybe the "foo" device supports one more dtype.

-    (2) ``is_autocast_enabled() -> bool``
-        check the AMP is enabled or not on your "foo" device.
-
-    (3) ``get_autocast_dtype() -> torch.dtype``
-        get the supported dtype on your "foo" device in AMP, which is set by ``set_autocast_dtype`` or the
-        default dtype, and the default dtype is ``torch.float16``.
-
-    (4) ``set_autocast_enabled(bool) -> None``
-        enable the AMP or not on your "foo" device.
-
-    (5) ``set_autocast_dtype(dtype) -> None``
-        set the supported dtype on your "foo" device in AMP, and the dtype be contained in the dtypes got
-        from ``get_amp_supported_dtype``.
-
     Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's:

     (1) ``_is_in_bad_fork() -> bool``
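
With the four state funcs handled by the shared autocast machinery, an out-of-tree backend only has to expose `get_amp_supported_dtype`. A hypothetical sketch (the `foo` name and the stand-in module object are illustrative, not part of this PR):

```python
import torch

class _FooBackendModule:
    # The only AMP hook a privateuse1 backend must still provide after this change.
    @staticmethod
    def get_amp_supported_dtype():
        return [torch.float16, torch.bfloat16]

torch.utils.rename_privateuse1_backend("foo")
torch._register_device_module("foo", _FooBackendModule)
```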


@@ -194,7 +194,7 @@ def set_device_states(devices, states) -> None:

 def _get_autocast_kwargs(device="cuda"):

-    if _supports_autocast(device):
+    if torch._C._is_autocast_available(device):
         device_autocast_kwargs = {
             "enabled": torch.is_autocast_enabled(device),
             "dtype": torch.get_autocast_dtype(device),
@@ -211,10 +211,6 @@ def _get_autocast_kwargs(device="cuda"):
     return device_autocast_kwargs, cpu_autocast_kwargs

-def _supports_autocast(device):
-    device_module = _get_device_module(device)
-    return device == "cuda" or (hasattr(device_module, "is_autocast_enabled")
-                                and hasattr(device_module, "get_autocast_dtype"))

 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
@@ -293,7 +289,7 @@ class CheckpointFunction(torch.autograd.Function):
             device_autocast_ctx = device_module.amp.autocast(
                 **ctx.device_autocast_kwargs
-            ) if _supports_autocast(ctx.device) else contextlib.nullcontext()
+            ) if torch._C._is_autocast_available(ctx.device) else contextlib.nullcontext()
             with torch.enable_grad(), device_autocast_ctx, \
                  torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
                 outputs = ctx.run_function(*detached_inputs)
@@ -1400,7 +1396,7 @@ def _checkpoint_without_reentrant_generator(
         device_autocast_ctx = device_module.amp.autocast(
             **device_autocast_kwargs
-        ) if _supports_autocast(device) else contextlib.nullcontext()
+        ) if torch._C._is_autocast_available(device) else contextlib.nullcontext()
         with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
              recompute_context:
             fn(*args, **kwargs)
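
Replacing `_supports_autocast` with `torch._C._is_autocast_available` keeps checkpoint's guarded-context pattern intact. A standalone sketch of that pattern (the helper name `_autocast_ctx` is hypothetical, not part of this PR):

```python
import contextlib
import torch

def _autocast_ctx(device: str, **kwargs):
    # Hypothetical helper mirroring the pattern above: fall back to a no-op
    # context when the device type has no autocast support.
    if torch._C._is_autocast_available(device):
        return torch.amp.autocast(device_type=device, **kwargs)
    return contextlib.nullcontext()

with _autocast_ctx("cpu", dtype=torch.bfloat16):
    pass  # run the checkpointed function here
```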