Refactor autocast C++ APIs to be device-agnostic (#124359)

# Motivation
This PR aims to refactor autocast **C++** APIs to be device-agnostic and deprecate the device-specific autocast  **C++** APIs.
In C++ side,
- `is_enabled()` -> `is_enabled(device_type)`.
- `set_enabled(new_enabled)` -> `set_enabled(device_type, new_enabled)`.
- `get_autocast_dtype()` -> `get_autocast_dtype(device_type)`
- `set_autocast_dtype(dtype)` -> `set_autocast_dtype(device_type, dtype)`

The following C++ APIs are deprecated and should be removed in PyTorch 2.5
- `is_cpu_enabled`
- `set_cpu_enabled`
- `get_autocast_cpu_dtype`
- `set_autocast_cpu_dtype`
- `is_xpu_enabled`
- `set_xpu_enabled`
- `get_autocast_xpu_dtype`
- `set_autocast_xpu_dtype`
- `is_ipu_enabled`
- `set_ipu_enabled`
- `get_autocast_ipu_dtype`
- `set_autocast_ipu_dtype`
- `is_hpu_enabled`
- `set_hpu_enabled`
- `get_autocast_hpu_dtype`
- `set_autocast_hpu_dtype`
- `is_xla_enabled`
- `set_xla_enabled`
- `get_autocast_xla_dtype`
- `set_autocast_xla_dtype`
- `is_privateuseone_enabled`
- `set_privateuseone_enabled`
- `get_autocast_privateuseone_dtype`
- `set_autocast_privateuseone_dtype`

In Python side,
provide 4 generic autocast APIs:
- `torch.is_autocast_enabled(device_type)`
- `torch.set_autocast_enabled(device_type, new_enabled)`
- `torch.get_autocast_dtype(device_type)`
- `torch.set_autocast_dtype(device_type, dtype)`

# Additional Context
We will submit another PR to refactor autocast **Python** APIs based on this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124359
Approved by: https://github.com/jgong5, https://github.com/albanD
This commit is contained in:
Yu, Guangye
2024-04-23 09:34:45 +00:00
committed by PyTorch MergeBot
parent 3c964ad1ca
commit 25f321b84f
12 changed files with 335 additions and 234 deletions

View File

@ -6,60 +6,14 @@
namespace at::autocast {
bool is_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA);
bool is_autocast_enabled(at::DeviceType device_type) {
at::DispatchKey dispatch_key = get_autocast_dispatch_key_from_device_type(device_type);
return !c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
}
void set_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCUDA, !new_enabled);
}
bool is_cpu_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCPU);
}
void set_cpu_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCPU, !new_enabled);
}
bool is_xpu_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastXPU);
}
void set_xpu_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastXPU, !new_enabled);
}
bool is_ipu_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastIPU);
}
void set_ipu_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastIPU, !new_enabled);
}
bool is_hpu_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastHPU);
}
void set_hpu_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastHPU, !new_enabled);
}
bool is_xla_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastXLA);
}
void set_xla_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastXLA, !new_enabled);
}
bool is_privateuseone_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1);
}
void set_privateuseone_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1, !new_enabled);
void set_autocast_enabled(at::DeviceType device_type, bool enabled) {
at::DispatchKey dispatch_key = get_autocast_dispatch_key_from_device_type(device_type);
c10::impl::tls_set_dispatch_key_excluded(dispatch_key, !enabled);
}
namespace {
@ -91,30 +45,40 @@ std::mutex cached_casts_mutex;
// it calls clear_cache() to ensure cached Tensors don't leak outside the autocasting region.
thread_local int nesting = 0;
// autocast_cpu_dtype is the lower_precision_fp used by AutocastCPU.
thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16;
// autocast_xpu_dtype is the lower_precision_fp used by AutocastXPU.
thread_local at::ScalarType autocast_xpu_dtype = at::kBFloat16;
// autocast_ipu_dtype is the lower_precision_fp used by AutocastIPU.
thread_local at::ScalarType autocast_ipu_dtype = at::kHalf;
// autocast_hpu_dtype is the lower_precision_fp used by AutocastHPU.
thread_local at::ScalarType autocast_hpu_dtype = at::kBFloat16;
// autocast_xla_dtype is the lower_precision_fp used by AutocastXLA.
thread_local at::ScalarType autocast_xla_dtype = at::kBFloat16;
// The order of this array MUST exactly match the definition order of DeviceType
// in c10/core/DeviceType.h.
static_assert(
at::COMPILE_TIME_MAX_DEVICE_TYPES == 21,
"The definition of the default autocast data type per device backend doesn't match with the definition of the device type.");
thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
autocast_dtype = {
at::kBFloat16, // CPU
at::kHalf, // CUDA.
at::ScalarType::Undefined, // Reserved for explicit MKLDNN
at::ScalarType::Undefined, // OpenGL
at::ScalarType::Undefined, // OpenCL
at::ScalarType::Undefined, // IDEEP.
at::kHalf, // AMD HIP
at::ScalarType::Undefined, // FPGA
at::ScalarType::Undefined, // ONNX Runtime / Microsoft
at::kBFloat16, // XLA / TPU
at::ScalarType::Undefined, // Vulkan
at::ScalarType::Undefined, // Metal
at::kBFloat16, // XPU
at::ScalarType::Undefined, // MPS
at::ScalarType::Undefined, // Meta (tensors with no data)
at::kBFloat16, // HPU / HABANA
at::ScalarType::Undefined, // SX-Aurora / NEC
at::ScalarType::Undefined, // Lazy Tensors
at::kHalf, // Graphcore IPU
at::ScalarType::Undefined, // Meta training and inference devices
at::kHalf, // PrivateUse1 device
};
// should we enable the cache inside autocast.
thread_local bool cache_enabled = true;
// autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;
// autocast_privateuseone_dtype is the lower_precision_fp used by AutocastPrivateUse1.
thread_local at::ScalarType autocast_privateuseone_dtype = at::kHalf;
}
} // anonymous namespace
void clear_cache() {
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
@ -129,60 +93,12 @@ int decrement_nesting() {
return --nesting;
}
at::ScalarType get_autocast_gpu_dtype() {
return autocast_gpu_dtype;
at::ScalarType get_autocast_dtype(at::DeviceType device_type) {
return autocast_dtype[static_cast<int>(device_type)];
}
at::ScalarType get_autocast_cpu_dtype() {
return autocast_cpu_dtype;
}
at::ScalarType get_autocast_xpu_dtype() {
return autocast_xpu_dtype;
}
at::ScalarType get_autocast_ipu_dtype() {
return autocast_ipu_dtype;
}
at::ScalarType get_autocast_hpu_dtype() {
return autocast_hpu_dtype;
}
at::ScalarType get_autocast_xla_dtype() {
return autocast_xla_dtype;
}
at::ScalarType get_autocast_privateuseone_dtype() {
return autocast_privateuseone_dtype;
}
void set_autocast_cpu_dtype(at::ScalarType dtype) {
autocast_cpu_dtype = dtype;
}
void set_autocast_gpu_dtype(at::ScalarType dtype) {
autocast_gpu_dtype = dtype;
}
void set_autocast_xpu_dtype(at::ScalarType dtype) {
autocast_xpu_dtype = dtype;
}
void set_autocast_ipu_dtype(at::ScalarType dtype) {
autocast_ipu_dtype = dtype;
}
void set_autocast_hpu_dtype(at::ScalarType dtype) {
autocast_hpu_dtype = dtype;
}
void set_autocast_xla_dtype(at::ScalarType dtype) {
autocast_xla_dtype = dtype;
}
void set_autocast_privateuseone_dtype(at::ScalarType dtype) {
autocast_privateuseone_dtype = dtype;
void set_autocast_dtype(at::DeviceType device_type, at::ScalarType dtype) {
autocast_dtype[static_cast<int>(device_type)] = dtype;
}
bool is_autocast_cache_enabled() {

View File

@ -10,40 +10,120 @@
namespace at::autocast {
TORCH_API bool is_enabled();
TORCH_API void set_enabled(bool enabled);
TORCH_API bool is_autocast_enabled(at::DeviceType device_type);
TORCH_API void set_autocast_enabled(at::DeviceType device_type, bool enabled);
TORCH_API at::ScalarType get_autocast_dtype(at::DeviceType device_type);
TORCH_API void set_autocast_dtype(
at::DeviceType device_type,
at::ScalarType dtype);
TORCH_API void clear_cache();
TORCH_API int increment_nesting();
TORCH_API int decrement_nesting();
TORCH_API bool is_cpu_enabled();
TORCH_API void set_cpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_gpu_dtype();
TORCH_API at::ScalarType get_autocast_cpu_dtype();
TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
TORCH_API bool is_xpu_enabled();
TORCH_API void set_xpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_xpu_dtype();
TORCH_API void set_autocast_xpu_dtype(at::ScalarType dtype);
TORCH_API bool is_ipu_enabled();
TORCH_API void set_ipu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_ipu_dtype();
TORCH_API void set_autocast_ipu_dtype(at::ScalarType dtype);
TORCH_API bool is_hpu_enabled();
TORCH_API void set_hpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_hpu_dtype();
TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
TORCH_API bool is_xla_enabled();
TORCH_API void set_xla_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_xla_dtype();
TORCH_API void set_autocast_xla_dtype(at::ScalarType dtype);
TORCH_API bool is_privateuseone_enabled();
TORCH_API void set_privateuseone_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_privateuseone_dtype();
TORCH_API void set_autocast_privateuseone_dtype(at::ScalarType dtype);
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);
// deprecated CUDA-specific autocast APIs
C10_DEPRECATED_MESSAGE(
"at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
TORCH_API inline bool is_enabled() {
TORCH_WARN_DEPRECATION(
"at::autocast::",
__func__,
"() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
return is_autocast_enabled(at::kCUDA);
}
C10_DEPRECATED_MESSAGE(
"at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
TORCH_API inline void set_enabled(bool enabled) {
TORCH_WARN_DEPRECATION(
"at::autocast::",
__func__,
"(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
set_autocast_enabled(at::kCUDA, enabled);
}
C10_DEPRECATED_MESSAGE(
"at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
TORCH_API inline at::ScalarType get_autocast_gpu_dtype() {
TORCH_WARN_DEPRECATION(
"at::autocast::",
__func__,
"() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
return get_autocast_dtype(at::kCUDA);
}
C10_DEPRECATED_MESSAGE(
"at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
TORCH_WARN_DEPRECATION(
"at::autocast::",
__func__,
"(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
set_autocast_dtype(at::kCUDA, dtype);
}
#define DECLARE_DEPRECATED_AUTOCAST_APIS(name, device_type) \
C10_DEPRECATED_MESSAGE( \
"at::autocast::is_" #name \
"_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
") instead.") \
TORCH_API inline bool is_##name##_enabled() { \
TORCH_WARN_DEPRECATION( \
"at::autocast::", \
__func__, \
"() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
") instead.") \
return is_autocast_enabled(device_type); \
} \
\
C10_DEPRECATED_MESSAGE( \
"at::autocast::set_" #name \
"_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
", enabled) instead.") \
TORCH_API inline void set_##name##_enabled(bool enabled) { \
TORCH_WARN_DEPRECATION( \
"at::autocast::", \
__func__, \
"(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
", enabled) instead.") \
set_autocast_enabled(device_type, enabled); \
} \
\
C10_DEPRECATED_MESSAGE( \
"at::autocast::get_autocast_" #name \
"_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
") instead.") \
TORCH_API inline at::ScalarType get_autocast_##name##_dtype() { \
TORCH_WARN_DEPRECATION( \
"at::autocast::", \
__func__, \
"() is deprecated. Please at::autocast::get_autocast_dtype(" #device_type \
") instead.") \
return get_autocast_dtype(device_type); \
} \
\
C10_DEPRECATED_MESSAGE( \
"at::autocast::set_autocast_" #name \
"_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
", dtype) instead.") \
TORCH_API inline void set_autocast_##name##_dtype(at::ScalarType dtype) { \
TORCH_WARN_DEPRECATION( \
"at::autocast::", \
__func__, \
"(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
", dtype) instead.") \
set_autocast_dtype(device_type, dtype); \
}
#define AT_FORALL_DEPRECATED_AUTOCAST_BAKCNEDS(_) \
_(cpu, at::kCPU) \
_(xpu, at::kXPU) \
_(xla, at::kXLA) \
_(hpu, at::kHPU) \
_(ipu, at::kIPU) \
_(privateuseone, at::kPrivateUse1)
// deprecated other backend specific autocast APIs
AT_FORALL_DEPRECATED_AUTOCAST_BAKCNEDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
namespace {
inline bool is_autocast_eligible(
const Tensor& tensor,
@ -96,22 +176,12 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
inline at::ScalarType get_lower_precision_fp_from_device_type(
c10::DeviceType device_type) {
switch (device_type) {
case c10::DeviceType::CUDA:
return get_autocast_gpu_dtype();
case c10::DeviceType::CPU:
return get_autocast_cpu_dtype();
case c10::DeviceType::XPU:
return get_autocast_xpu_dtype();
case c10::DeviceType::IPU:
return get_autocast_ipu_dtype();
case c10::DeviceType::HPU:
return get_autocast_hpu_dtype();
case c10::DeviceType::XLA:
return get_autocast_xla_dtype();
case c10::DeviceType::PrivateUse1:
return get_autocast_privateuseone_dtype();
default:
if (device_type == at::kCPU || device_type == at::kCUDA ||
device_type == at::kXPU || device_type == at::kIPU ||
device_type == at::kHPU || device_type == at::kXLA ||
device_type == at::kPrivateUse1) {
return get_autocast_dtype(device_type);
} else {
throw std::runtime_error(
"unknown device type for autocast in get_lower_precision_fp_from_device_type");
}

View File

@ -139,6 +139,7 @@ class TestPublicBindings(TestCase):
"Generator",
"GeneratorType",
"get_autocast_cpu_dtype",
"get_autocast_dtype",
"get_autocast_ipu_dtype",
"get_default_dtype",
"get_num_interop_threads",
@ -216,6 +217,7 @@ class TestPublicBindings(TestCase):
"set_anomaly_enabled",
"set_autocast_cache_enabled",
"set_autocast_cpu_dtype",
"set_autocast_dtype",
"set_autocast_ipu_dtype",
"set_autocast_cpu_enabled",
"set_autocast_ipu_enabled",

View File

@ -1249,8 +1249,16 @@ def is_grad_enabled() -> _bool: ...
def _set_fwd_grad_enabled(enabled: _bool) -> None: ...
def _is_fwd_grad_enabled() -> _bool: ...
def is_inference_mode_enabled() -> _bool: ...
@overload
def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ...
@overload
def set_autocast_enabled(enabled: _bool) -> None: ...
@overload
def is_autocast_enabled(device_type: str) -> _bool: ...
@overload
def is_autocast_enabled() -> _bool: ...
def set_autocast_dtype(device_type: str, dtype: _dtype) -> None: ...
def get_autocast_dtype(device_type: str) -> _dtype: ...
def clear_autocast_cache() -> None: ...
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
def is_autocast_cpu_enabled() -> _bool: ...

View File

@ -595,21 +595,30 @@ class OutputGraph:
self.torch_function_enabled,
)
global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
def autocast_specific_backend(
device_type: str, func: Callable[[str, Any], None]
):
def decorator(value):
return func(device_type, value)
return decorator
global_state["autocast_enabled"] = (
torch.set_autocast_enabled,
torch.is_autocast_enabled(),
autocast_specific_backend("cuda", torch.set_autocast_enabled),
torch.is_autocast_enabled("cuda"),
)
global_state["autocast_cpu_enabled"] = (
torch.set_autocast_cpu_enabled,
torch.is_autocast_cpu_enabled(),
autocast_specific_backend("cpu", torch.set_autocast_enabled),
torch.is_autocast_enabled("cpu"),
)
global_state["autocast_gpu_dtype"] = (
torch.set_autocast_gpu_dtype,
torch.get_autocast_gpu_dtype(),
autocast_specific_backend("cuda", torch.set_autocast_dtype),
torch.get_autocast_dtype("cuda"),
)
global_state["autocast_cpu_dtype"] = (
torch.set_autocast_cpu_dtype,
torch.get_autocast_cpu_dtype(),
autocast_specific_backend("cpu", torch.set_autocast_dtype),
torch.get_autocast_dtype("cpu"),
)
global_state["autocast_cache_enabled"] = (
torch.set_autocast_cache_enabled,

View File

@ -79,10 +79,10 @@ def normalize_as_list(x):
def _get_autocast_states():
return [
torch.is_autocast_enabled(),
torch.is_autocast_cpu_enabled(),
torch.get_autocast_gpu_dtype(),
torch.get_autocast_cpu_dtype(),
torch.is_autocast_enabled("cuda"),
torch.is_autocast_enabled("cpu"),
torch.get_autocast_dtype("cuda"),
torch.get_autocast_dtype("cpu"),
torch.is_autocast_cache_enabled(),
]

View File

@ -474,24 +474,47 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
Py_RETURN_TRUE;
}
namespace torch {
namespace autograd {
namespace torch::autograd {
static PyObject* set_autocast_enabled(PyObject* _unused, PyObject* arg) {
static PyObject* set_autocast_enabled(
PyObject* _unused,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
TORCH_CHECK_TYPE(
PyBool_Check(arg),
"enabled must be a bool (got ",
Py_TYPE(arg)->tp_name,
")");
at::autocast::set_enabled(arg == Py_True);
static PythonArgParser parser(
{"set_autocast_enabled(c10::string_view device_type, bool enabled)",
"set_autocast_enabled(bool enabled)"}); // this signature is deprecated.
ParsedArgs<2> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
// Set at::kCUDA as default value to prevent BC-breaking changes.
at::DeviceType device_type = at::kCUDA;
int enabled_id = 0;
if (r.idx == 0) {
device_type = at::Device(r.string(0)).type();
enabled_id = 1;
}
auto enabled = r.toBool(enabled_id);
at::autocast::set_autocast_enabled(device_type, enabled);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) {
static PyObject* is_autocast_enabled(
PyObject* _unused,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
if (at::autocast::is_enabled()) {
static PythonArgParser parser(
{"is_autocast_enabled(c10::string_view device_type)",
"is_autocast_enabled()"}); // this signature is deprecated.
ParsedArgs<1> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
// Set at::kCUDA as default value to prevent BC-breaking changes.
at::DeviceType device_type = at::kCUDA;
if (r.idx == 0) {
device_type = at::Device(r.string(0)).type();
}
if (at::autocast::is_autocast_enabled(device_type)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
@ -499,11 +522,48 @@ static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) {
END_HANDLE_TH_ERRORS
}
static PyObject* get_autocast_dtype(
PyObject* _unused,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser(
{"get_autocast_dtype(c10::string_view device_type)"});
ParsedArgs<1> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto device_type = at::Device(r.string(0)).type();
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(device_type);
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
Py_INCREF(dtype);
return dtype;
END_HANDLE_TH_ERRORS
}
static PyObject* set_autocast_dtype(
PyObject* _unused,
PyObject* args,
PyObject* kwargs) {
HANDLE_TH_ERRORS
static PythonArgParser parser(
{"set_autocast_dtype(c10::string_view device_type, ScalarType dtype)"});
ParsedArgs<2> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
auto device_type = at::Device(r.string(0)).type();
auto dtype = r.scalartype(1);
at::autocast::set_autocast_dtype(device_type, dtype);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (at::autocast::is_enabled() || at::autocast::is_cpu_enabled() ||
at::autocast::is_xpu_enabled() || at::autocast::is_ipu_enabled() ||
at::autocast::is_xla_enabled() || at::autocast::is_hpu_enabled()) {
if (at::autocast::is_autocast_enabled(at::kCPU) ||
at::autocast::is_autocast_enabled(at::kCUDA) ||
at::autocast::is_autocast_enabled(at::kXPU) ||
at::autocast::is_autocast_enabled(at::kIPU) ||
at::autocast::is_autocast_enabled(at::kXLA) ||
at::autocast::is_autocast_enabled(at::kHPU) ||
at::autocast::is_autocast_enabled(at::kPrivateUse1)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
@ -518,14 +578,18 @@ static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
"enabled must be a bool (got ",
Py_TYPE(arg)->tp_name,
")");
at::autocast::set_cpu_enabled(arg == Py_True);
TORCH_WARN_DEPRECATION(
"torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead.")
at::autocast::set_autocast_enabled(at::kCPU, arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* is_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (at::autocast::is_cpu_enabled()) {
TORCH_WARN_DEPRECATION(
"torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead.")
if (at::autocast::is_autocast_enabled(at::kCPU)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
@ -540,14 +604,18 @@ static PyObject* set_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) {
"enabled must be a bool (got ",
Py_TYPE(arg)->tp_name,
")");
at::autocast::set_ipu_enabled(arg == Py_True);
TORCH_WARN_DEPRECATION(
"torch.set_autocast_ipu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('ipu', enabled) instead.")
at::autocast::set_autocast_enabled(at::kIPU, arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* is_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (at::autocast::is_ipu_enabled()) {
TORCH_WARN_DEPRECATION(
"torch.is_autocast_ipu_enabled() is deprecated. Please use torch.is_autocast_enabled('ipu') instead.")
if (at::autocast::is_autocast_enabled(at::kIPU)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
@ -562,14 +630,18 @@ static PyObject* set_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
"enabled must be a bool (got ",
Py_TYPE(arg)->tp_name,
")");
at::autocast::set_xla_enabled(arg == Py_True);
TORCH_WARN_DEPRECATION(
"torch.set_autocast_xla_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('xla', enabled) instead.")
at::autocast::set_autocast_enabled(at::kXLA, arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* is_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (at::autocast::is_xla_enabled()) {
TORCH_WARN_DEPRECATION(
"torch.is_autocast_xla_enabled() is deprecated. Please use torch.is_autocast_enabled('xla') instead.")
if (at::autocast::is_autocast_enabled(at::kXLA)) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
@ -584,8 +656,10 @@ static PyObject* set_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
"dtype must be a torch.dtype (got ",
Py_TYPE(arg)->tp_name,
")");
TORCH_WARN_DEPRECATION(
"torch.set_autocast_gpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cuda', dtype) instead.")
at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
at::autocast::set_autocast_gpu_dtype(targetType);
at::autocast::set_autocast_dtype(at::kCUDA, targetType);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@ -597,8 +671,10 @@ static PyObject* set_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
"dtype must be a torch.dtype (got ",
Py_TYPE(arg)->tp_name,
")");
TORCH_WARN_DEPRECATION(
"torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead.")
at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
at::autocast::set_autocast_cpu_dtype(targetType);
at::autocast::set_autocast_dtype(at::kCPU, targetType);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@ -610,8 +686,10 @@ static PyObject* set_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
"dtype must be a torch.dtype (got ",
Py_TYPE(arg)->tp_name,
")");
TORCH_WARN_DEPRECATION(
"torch.set_autocast_ipu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('ipu', dtype) instead.")
at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
at::autocast::set_autocast_ipu_dtype(targetType);
at::autocast::set_autocast_dtype(at::kIPU, targetType);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@ -623,15 +701,19 @@ static PyObject* set_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
"dtype must be a torch.dtype (got ",
Py_TYPE(arg)->tp_name,
")");
TORCH_WARN_DEPRECATION(
"torch.set_autocast_xla_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('xla', dtype) instead.")
at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
at::autocast::set_autocast_xla_dtype(targetType);
at::autocast::set_autocast_dtype(at::kXLA, targetType);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype();
TORCH_WARN_DEPRECATION(
"torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead.")
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCUDA);
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
Py_INCREF(dtype);
return dtype;
@ -640,7 +722,9 @@ static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype();
TORCH_WARN_DEPRECATION(
"torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead.")
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCPU);
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
Py_INCREF(dtype);
return dtype;
@ -649,7 +733,9 @@ static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
static PyObject* get_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_ipu_dtype();
TORCH_WARN_DEPRECATION(
"torch.get_autocast_ipu_dtype() is deprecated. Please use torch.get_autocast_dtype('ipu') instead.")
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kIPU);
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
Py_INCREF(dtype);
return dtype;
@ -658,7 +744,9 @@ static PyObject* get_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
static PyObject* get_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_xla_dtype();
TORCH_WARN_DEPRECATION(
"torch.get_autocast_xla_dtype() is deprecated. Please use torch.get_autocast_dtype('xla') instead.")
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kXLA);
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
Py_INCREF(dtype);
return dtype;
@ -1123,8 +1211,22 @@ static PyMethodDef methods[] = { // NOLINT
is_inference_mode_enabled,
METH_NOARGS,
nullptr},
{"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr},
{"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr},
{"set_autocast_enabled",
castPyCFunctionWithKeywords(set_autocast_enabled),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"is_autocast_enabled",
castPyCFunctionWithKeywords(is_autocast_enabled),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"set_autocast_dtype",
castPyCFunctionWithKeywords(set_autocast_dtype),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"get_autocast_dtype",
castPyCFunctionWithKeywords(get_autocast_dtype),
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
{"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
{"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
@ -1225,5 +1327,4 @@ PyMethodDef* python_functions() {
return methods;
}
} // namespace autograd
} // namespace torch
} // namespace torch::autograd

View File

@ -116,8 +116,8 @@ GraphFunction::SpecializationKey GraphFunction::currentSpecialization() const {
// disabling autodiff pass for mobile build since autocast APIs don't exist
return SpecializationKey::AutocastOff;
#else
bool cpu_enabled = at::autocast::is_cpu_enabled();
bool gpu_enabled = at::autocast::is_enabled();
bool cpu_enabled = at::autocast::is_autocast_enabled(at::kCPU);
bool gpu_enabled = at::autocast::is_autocast_enabled(at::kCUDA);
if (cpu_enabled && gpu_enabled) {
return SpecializationKey::CpuGpuAutocastOn;
} else if (!cpu_enabled && !gpu_enabled) {

View File

@ -521,10 +521,10 @@ void Autocast(const std::shared_ptr<Graph>& graph) {
GRAPH_DUMP("\nBefore Autocast: ", graph);
if (autocastEnabled()) {
AutocastContext init = {
at::autocast::is_enabled(),
at::autocast::is_cpu_enabled(),
at::autocast::get_autocast_gpu_dtype(),
at::autocast::get_autocast_cpu_dtype()};
at::autocast::is_autocast_enabled(at::kCUDA),
at::autocast::is_autocast_enabled(at::kCPU),
at::autocast::get_autocast_dtype(at::kCUDA),
at::autocast::get_autocast_dtype(at::kCPU)};
handleBlock(graph->block(), init);
}
GRAPH_DUMP("\nAfter Autocast: ", graph);

View File

@ -799,7 +799,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
bool enabled = false;
#else
bool enabled = at::autocast::is_enabled();
bool enabled = at::autocast::is_autocast_enabled(at::kCUDA);
#endif
push(stack, enabled);
},
@ -810,7 +810,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
bool enabled = false;
#else
bool enabled = at::autocast::is_cpu_enabled();
bool enabled = at::autocast::is_autocast_enabled(at::kCPU);
#endif
push(stack, enabled);
},

View File

@ -256,6 +256,8 @@ def get_ignored_functions() -> Set[Callable]:
handle_torch_function,
torch.set_autocast_enabled,
torch.is_autocast_enabled,
torch.set_autocast_dtype,
torch.get_autocast_dtype,
torch.clear_autocast_cache,
torch.set_autocast_cpu_enabled,
torch.is_autocast_cpu_enabled,

View File

@ -194,25 +194,18 @@ def set_device_states(devices, states) -> None:
def _get_autocast_kwargs(device="cuda"):
if device == "cuda":
if _supports_autocast(device):
device_autocast_kwargs = {
"enabled": torch.is_autocast_enabled(),
"dtype": torch.get_autocast_gpu_dtype(),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
elif _supports_autocast(device):
device_module = _get_device_module(device)
device_autocast_kwargs = {
"enabled": device_module.is_autocast_enabled(),
"dtype": device_module.get_autocast_dtype(),
"enabled": torch.is_autocast_enabled(device),
"dtype": torch.get_autocast_dtype(device),
"cache_enabled": torch.is_autocast_cache_enabled(),
}
else:
device_autocast_kwargs = None
cpu_autocast_kwargs = {
"enabled": torch.is_autocast_cpu_enabled(),
"dtype": torch.get_autocast_cpu_dtype(),
"enabled": torch.is_autocast_enabled('cpu'),
"dtype": torch.get_autocast_dtype('cpu'),
"cache_enabled": torch.is_autocast_cache_enabled(),
}