refactor autocast python APIs (#124479)
# Motivation
Refactor the autocast usage in `torch/amp/autocast_mode.py` and `torch/utils/checkpoint.py` to fix a naming-convention conflict between `torch.xxx.get_autocast_xxx_dtype` (defined in `autocast_mode.py`) and `torch.xxx.get_autocast_dtype` (defined in `checkpoint.py`).

# Solution
Use the device-agnostic APIs, such as `torch.get_autocast_dtype`, instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124479
Approved by: https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/albanD
ghstack dependencies: #124359
Commit cdc66e9dc3 (parent f01275934b), committed by PyTorch MergeBot
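To make the refactor concrete, here is a short illustrative sketch (not part of the commit) that exercises the device-agnostic autocast APIs the hunks below switch to; the tensor shapes and the choice of `torch.bfloat16` on CPU are arbitrary.

```python
import torch

device_type = "cpu"  # any autocast-capable backend string works the same way

# Query the current autocast state for this device type.
prev_enabled = torch.is_autocast_enabled(device_type)
prev_dtype = torch.get_autocast_dtype(device_type)

# The high-level context manager now routes every backend through the same
# torch.get_autocast_dtype / torch.set_autocast_enabled style of call.
with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
    x = torch.randn(8, 8)
    y = torch.mm(x, x)
    print(y.dtype)  # typically torch.bfloat16: mm is autocast-eligible on CPU

# The per-device state is restored on exit.
assert torch.is_autocast_enabled(device_type) == prev_enabled
assert torch.get_autocast_dtype(device_type) == prev_dtype
```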
@@ -174,12 +174,20 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
   }
 }

-inline at::ScalarType get_lower_precision_fp_from_device_type(
-    c10::DeviceType device_type) {
+inline bool is_autocast_available(c10::DeviceType device_type) {
   if (device_type == at::kCPU || device_type == at::kCUDA ||
       device_type == at::kXPU || device_type == at::kIPU ||
       device_type == at::kHPU || device_type == at::kXLA ||
       device_type == at::kPrivateUse1) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+inline at::ScalarType get_lower_precision_fp_from_device_type(
+    c10::DeviceType device_type) {
+  if (is_autocast_available(device_type)) {
     return get_autocast_dtype(device_type);
   } else {
     throw std::runtime_error(
@@ -336,7 +336,7 @@ class TestTorchAutocast(TestCase):

     def test_invalid_device(self):
         dev = "not a real device"
-        msg = f"unsupported autocast device_type '{dev}'"
+        msg = f"Invalid device string: '{dev}'"
         with self.assertRaisesRegex(RuntimeError, msg):
             with torch.autocast(device_type=dev):
                 _ = torch.tensor(1)
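As a hedged illustration of the behavioral change the test above captures: the device string is now validated by device parsing inside `torch._C._is_autocast_available`, so the error text differs from the old "unsupported autocast device_type" message.

```python
import torch

try:
    # Construction now fails while the device string is being parsed.
    torch.autocast(device_type="not a real device")
except RuntimeError as e:
    print(e)  # e.g. "Invalid device string: 'not a real device'"
```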
@@ -1301,6 +1301,7 @@ def clear_autocast_cache() -> None: ...
 def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
 def is_autocast_cpu_enabled() -> _bool: ...
 def _is_any_autocast_enabled() -> _bool: ...
+def _is_autocast_available(device_type: str) -> _bool: ...
 def set_autocast_cpu_dtype(dtype: _dtype) -> None: ...
 def set_autocast_gpu_dtype(dtype: _dtype) -> None: ...
 def get_autocast_cpu_dtype() -> _dtype: ...
@@ -199,35 +199,20 @@ class autocast:
             assert dtype is not None
             return
         self.device = device_type
+        if not torch._C._is_autocast_available(self.device):
+            raise RuntimeError(
+                f"User specified an unsupported autocast device_type '{self.device}'"
+            )
         self.custom_backend_name = torch._C._get_privateuse1_backend_name()
-        if self.device == "cuda":
-            self.fast_dtype = torch.get_autocast_gpu_dtype()
-        elif self.device == "cpu":
-            self.fast_dtype = torch.get_autocast_cpu_dtype()
-        elif self.device == "xpu":
-            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "ipu":
-            self.fast_dtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "hpu":
-            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-        elif self.device == "xla":
-            self.fast_dtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
-        elif self.device == self.custom_backend_name:
+        self.fast_dtype = torch.get_autocast_dtype(self.device)
+        if self.device == self.custom_backend_name:
             necessary_funcs = [
-                "is_autocast_enabled",
-                "set_autocast_enabled",
-                "get_autocast_dtype",
-                "set_autocast_dtype",
                 "get_amp_supported_dtype",
             ]
             message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not "
             message += "registered a module or the module miss some necessary funcs. The backend should register "
             message += "a module by `torch._register_device_module`, and the module must have these funcs: \n"
-            message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, "
-            message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) "
-            message += (
-                "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"
-            )
+            message += "`get_amp_supported_dtype() -> List[torch.dtype]`. \n"

             assert hasattr(torch, self.custom_backend_name), message
             self.custom_device_mod = getattr(torch, self.custom_backend_name)
@@ -236,11 +221,6 @@ class autocast:
                         message + f"But the func `{func}` is missing. \n"
                     )

-            self.fast_dtype = self.custom_device_mod.get_autocast_dtype()
-        else:
-            raise RuntimeError(
-                f"User specified an unsupported autocast device_type '{self.device}'"
-            )
         self._cache_enabled = torch.is_autocast_cache_enabled()
         if (
             enabled
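For illustration only, a hypothetical sketch of what an out-of-tree (privateuse1) backend now has to provide: with enable/dtype state tracked by the device-agnostic autocast APIs, only `get_amp_supported_dtype` remains in `necessary_funcs`. The backend name `foo` and the module object are placeholders, not real PyTorch components.

```python
import types
from typing import List

import torch

# Hypothetical out-of-tree backend named "foo" (placeholder name).
torch.utils.rename_privateuse1_backend("foo")

foo_module = types.ModuleType("foo")

def get_amp_supported_dtype() -> List[torch.dtype]:
    # dtypes the hypothetical device accepts under AMP
    return [torch.float16, torch.bfloat16]

foo_module.get_amp_supported_dtype = get_amp_supported_dtype

# autocast's custom-backend branch only checks for the func above now;
# is/set_autocast_enabled and get/set_autocast_dtype are no longer required
# on the backend module.
torch._register_device_module("foo", foo_module)
```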
@@ -323,48 +303,11 @@ class autocast:
             return self

         self.prev_cache_enabled = torch.is_autocast_cache_enabled()
-        if self.device == "cpu":
-            self.prev = torch.is_autocast_cpu_enabled()
-            self.prev_fastdtype = torch.get_autocast_cpu_dtype()
-            torch.set_autocast_cpu_enabled(self._enabled)
-            torch.set_autocast_cpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
-            torch.autocast_increment_nesting()
-        elif self.device == "xpu":
-            self.prev = torch.xpu.is_autocast_xpu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "ipu":
-            self.prev = torch.is_autocast_ipu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "hpu":
-            self.prev = torch.hpu.is_autocast_hpu_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == "xla":
-            self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
-            self.prev_fastdtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
-            torch.set_autocast_xla_enabled(self._enabled)  # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.fast_dtype)  # type: ignore[attr-defined]
-            torch.autocast_increment_nesting()
-        elif self.device == self.custom_backend_name:
-            self.prev = self.custom_device_mod.is_autocast_enabled()
-            self.prev_fastdtype = self.custom_device_mod.get_autocast_dtype()
-            self.custom_device_mod.set_autocast_enabled(self._enabled)
-            self.custom_device_mod.set_autocast_dtype(self.fast_dtype)
-            torch.autocast_increment_nesting()
-        else:
-            self.prev = torch.is_autocast_enabled()
-            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
-            torch.set_autocast_gpu_dtype(self.fast_dtype)  # type: ignore[arg-type]
-            torch.set_autocast_enabled(self._enabled)
-            torch.autocast_increment_nesting()
+        self.prev = torch.is_autocast_enabled(self.device)
+        self.prev_fastdtype = torch.get_autocast_dtype(self.device)
+        torch.set_autocast_enabled(self.device, self._enabled)
+        torch.set_autocast_dtype(self.device, self.fast_dtype)  # type: ignore[arg-type]
+        torch.autocast_increment_nesting()
         torch.set_autocast_cache_enabled(self._cache_enabled)

     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):  # type: ignore[override]
@@ -372,41 +315,10 @@ class autocast:
             return

         # Drop the cache when we exit to a nesting level that's outside any instance of autocast.
-        if self.device == "cpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_cpu_enabled(self.prev)
-            torch.set_autocast_cpu_dtype(self.prev_fastdtype)
-        elif self.device == "xpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.xpu.set_autocast_xpu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.xpu.set_autocast_xpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "ipu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_ipu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.set_autocast_ipu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "hpu":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == "xla":
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_xla_enabled(self.prev)  # type: ignore[attr-defined]
-            torch.set_autocast_xla_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
-        elif self.device == self.custom_backend_name:
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            self.custom_device_mod.set_autocast_enabled(self.prev)
-            self.custom_device_mod.set_autocast_dtype(self.prev_fastdtype)
-        else:
-            if torch.autocast_decrement_nesting() == 0:
-                torch.clear_autocast_cache()
-            torch.set_autocast_enabled(self.prev)
-            torch.set_autocast_gpu_dtype(self.prev_fastdtype)
+        if torch.autocast_decrement_nesting() == 0:
+            torch.clear_autocast_cache()
+        torch.set_autocast_enabled(self.device, self.prev)
+        torch.set_autocast_dtype(self.device, self.prev_fastdtype)
         torch.set_autocast_cache_enabled(self.prev_cache_enabled)
         return False

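Purely for illustration, a minimal standalone sketch of the save/modify/restore pattern that `__enter__`/`__exit__` now follow for every backend (the real class additionally manages the autocast cache flag). `manual_autocast` is a hypothetical helper, not a PyTorch API; the calls inside it mirror the new device-agnostic APIs in the hunks above.

```python
import contextlib

import torch

@contextlib.contextmanager
def manual_autocast(device_type: str, dtype: torch.dtype, enabled: bool = True):
    # Save the previous autocast state for this device type.
    prev_enabled = torch.is_autocast_enabled(device_type)
    prev_dtype = torch.get_autocast_dtype(device_type)
    torch.set_autocast_enabled(device_type, enabled)
    torch.set_autocast_dtype(device_type, dtype)
    torch.autocast_increment_nesting()
    try:
        yield
    finally:
        # Drop the weight cache when leaving the outermost nesting level.
        if torch.autocast_decrement_nesting() == 0:
            torch.clear_autocast_cache()
        torch.set_autocast_enabled(device_type, prev_enabled)
        torch.set_autocast_dtype(device_type, prev_dtype)

with manual_autocast("cpu", torch.bfloat16):
    out = torch.mm(torch.randn(4, 4), torch.randn(4, 4))
```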
@@ -574,6 +574,24 @@ static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }

+static PyObject* is_autocast_available(
+    PyObject* _unused,
+    PyObject* args,
+    PyObject* kwargs) {
+  HANDLE_TH_ERRORS
+  static PythonArgParser parser(
+      {"_is_autocast_available(c10::string_view device_type)"});
+  ParsedArgs<1> parsed_args;
+  auto r = parser.parse(args, kwargs, parsed_args);
+  auto device_type = at::Device(r.string(0)).type();
+  if (at::autocast::is_autocast_available(device_type)) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
   TORCH_CHECK_TYPE(
@@ -1235,6 +1253,10 @@ static PyMethodDef methods[] = { // NOLINT
      METH_VARARGS | METH_KEYWORDS,
      nullptr},
     {"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
+    {"_is_autocast_available",
+     castPyCFunctionWithKeywords(is_autocast_available),
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
     {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
     {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
     {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},
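A small illustrative sketch (not part of the diff) of calling the new private binding from Python; because the argument is parsed as a device, indexed strings such as `"cuda:0"` are accepted as well.

```python
import torch

# True for the device types whitelisted in is_autocast_available():
# cpu, cuda, xpu, ipu, hpu, xla and the (possibly renamed) privateuse1 backend.
for dev in ("cpu", "cuda", "cuda:0", "xla"):
    print(dev, torch._C._is_autocast_available(dev))
```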
@@ -36,20 +36,6 @@ def rename_privateuse1_backend(backend_name: str) -> None:
     (1) ``get_amp_supported_dtype() -> List[torch.dtype]``
         get the supported dtypes on your "foo" device in AMP, maybe the "foo" device supports one more dtype.

-    (2) ``is_autocast_enabled() -> bool``
-        check the AMP is enabled or not on your "foo" device.
-
-    (3) ``get_autocast_dtype() -> torch.dtype``
-        get the supported dtype on your "foo" device in AMP, which is set by ``set_autocast_dtype`` or the
-        default dtype, and the default dtype is ``torch.float16``.
-
-    (4) ``set_autocast_enabled(bool) -> None``
-        enable the AMP or not on your "foo" device.
-
-    (5) ``set_autocast_dtype(dtype) -> None``
-        set the supported dtype on your "foo" device in AMP, and the dtype be contained in the dtypes got
-        from ``get_amp_supported_dtype``.
-
     Note(random): If you want to support to set seed for your device, BackendModule needs to have the following API's:

     (1) ``_is_in_bad_fork() -> bool``
@@ -194,7 +194,7 @@ def set_device_states(devices, states) -> None:


 def _get_autocast_kwargs(device="cuda"):
-    if _supports_autocast(device):
+    if torch._C._is_autocast_available(device):
         device_autocast_kwargs = {
             "enabled": torch.is_autocast_enabled(device),
             "dtype": torch.get_autocast_dtype(device),
@@ -211,10 +211,6 @@ def _get_autocast_kwargs(device="cuda"):

     return device_autocast_kwargs, cpu_autocast_kwargs

-def _supports_autocast(device):
-    device_module = _get_device_module(device)
-    return device == "cuda" or (hasattr(device_module, "is_autocast_enabled")
-                                and hasattr(device_module, "get_autocast_dtype"))

 class CheckpointFunction(torch.autograd.Function):
     @staticmethod
@@ -293,7 +289,7 @@ class CheckpointFunction(torch.autograd.Function):

         device_autocast_ctx = device_module.amp.autocast(
             **ctx.device_autocast_kwargs
-        ) if _supports_autocast(ctx.device) else contextlib.nullcontext()
+        ) if torch._C._is_autocast_available(ctx.device) else contextlib.nullcontext()
         with torch.enable_grad(), device_autocast_ctx, \
              torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
             outputs = ctx.run_function(*detached_inputs)
@@ -1400,7 +1396,7 @@ def _checkpoint_without_reentrant_generator(

         device_autocast_ctx = device_module.amp.autocast(
             **device_autocast_kwargs
-        ) if _supports_autocast(device) else contextlib.nullcontext()
+        ) if torch._C._is_autocast_available(device) else contextlib.nullcontext()
         with device_autocast_ctx, torch.cpu.amp.autocast(**cpu_autocast_kwargs), \
              recompute_context:
             fn(*args, **kwargs)
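To close, a hedged usage sketch of what the `torch/utils/checkpoint.py` changes affect: recomputation during backward replays the forward under the captured autocast state, and that capture is now gated by `torch._C._is_autocast_available(device)` instead of the removed `_supports_autocast` helper. The model function and sizes below are made up.

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x):
    # stand-in for an expensive sub-module whose activations we do not store
    return torch.mm(x, x).relu()

x = torch.randn(16, 16, requires_grad=True)

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    # use_reentrant=False is the currently recommended mode; during backward
    # the block is re-run with the autocast kwargs captured above.
    y = checkpoint(block, x, use_reentrant=False)

y.float().sum().backward()
print(x.grad.shape)  # torch.Size([16, 16])
```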