mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Refactor autocast C++ APIs to be device-agnostic (#124359)
# Motivation

This PR refactors the autocast **C++** APIs to be device-agnostic and deprecates the device-specific autocast **C++** APIs.

On the C++ side:
- `is_enabled()` -> `is_autocast_enabled(device_type)`
- `set_enabled(new_enabled)` -> `set_autocast_enabled(device_type, new_enabled)`
- `get_autocast_dtype()` -> `get_autocast_dtype(device_type)`
- `set_autocast_dtype(dtype)` -> `set_autocast_dtype(device_type, dtype)`

The following C++ APIs are deprecated and should be removed in PyTorch 2.5:
- `is_cpu_enabled`
- `set_cpu_enabled`
- `get_autocast_cpu_dtype`
- `set_autocast_cpu_dtype`
- `is_xpu_enabled`
- `set_xpu_enabled`
- `get_autocast_xpu_dtype`
- `set_autocast_xpu_dtype`
- `is_ipu_enabled`
- `set_ipu_enabled`
- `get_autocast_ipu_dtype`
- `set_autocast_ipu_dtype`
- `is_hpu_enabled`
- `set_hpu_enabled`
- `get_autocast_hpu_dtype`
- `set_autocast_hpu_dtype`
- `is_xla_enabled`
- `set_xla_enabled`
- `get_autocast_xla_dtype`
- `set_autocast_xla_dtype`
- `is_privateuseone_enabled`
- `set_privateuseone_enabled`
- `get_autocast_privateuseone_dtype`
- `set_autocast_privateuseone_dtype`

On the Python side, this PR provides 4 generic autocast APIs:
- `torch.is_autocast_enabled(device_type)`
- `torch.set_autocast_enabled(device_type, new_enabled)`
- `torch.get_autocast_dtype(device_type)`
- `torch.set_autocast_dtype(device_type, dtype)`

# Additional Context

We will submit another PR to refactor the autocast **Python** APIs based on this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124359
Approved by: https://github.com/jgong5, https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
3c964ad1ca
commit
25f321b84f
@ -6,60 +6,14 @@
|
||||
|
||||
namespace at::autocast {
|
||||
|
||||
bool is_enabled() {
|
||||
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCUDA);
|
||||
// Returns true when autocast is enabled for `device_type`, i.e. when the
// autocast dispatch key mapped from that device type is NOT reported as
// excluded by c10::impl::tls_is_dispatch_key_excluded.
bool is_autocast_enabled(at::DeviceType device_type) {
  const bool key_excluded = c10::impl::tls_is_dispatch_key_excluded(
      get_autocast_dispatch_key_from_device_type(device_type));
  return !key_excluded;
}
|
||||
|
||||
void set_enabled(bool new_enabled) {
|
||||
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCUDA, !new_enabled);
|
||||
}
|
||||
|
||||
bool is_cpu_enabled() {
|
||||
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastCPU);
|
||||
}
|
||||
|
||||
void set_cpu_enabled(bool new_enabled) {
|
||||
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastCPU, !new_enabled);
|
||||
}
|
||||
|
||||
bool is_xpu_enabled() {
|
||||
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastXPU);
|
||||
}
|
||||
|
||||
void set_xpu_enabled(bool new_enabled) {
|
||||
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastXPU, !new_enabled);
|
||||
}
|
||||
|
||||
bool is_ipu_enabled() {
|
||||
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastIPU);
|
||||
}
|
||||
|
||||
void set_ipu_enabled(bool new_enabled) {
|
||||
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastIPU, !new_enabled);
|
||||
}
|
||||
|
||||
bool is_hpu_enabled() {
|
||||
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastHPU);
|
||||
}
|
||||
|
||||
void set_hpu_enabled(bool new_enabled) {
|
||||
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastHPU, !new_enabled);
|
||||
}
|
||||
|
||||
bool is_xla_enabled() {
|
||||
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastXLA);
|
||||
}
|
||||
|
||||
void set_xla_enabled(bool new_enabled) {
|
||||
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastXLA, !new_enabled);
|
||||
}
|
||||
|
||||
bool is_privateuseone_enabled() {
|
||||
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1);
|
||||
}
|
||||
|
||||
void set_privateuseone_enabled(bool new_enabled) {
|
||||
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1, !new_enabled);
|
||||
// Enables or disables autocast for `device_type` by toggling the exclusion
// state of the autocast dispatch key mapped from that device type: the key is
// marked excluded exactly when `enabled` is false.
void set_autocast_enabled(at::DeviceType device_type, bool enabled) {
  const bool exclude_key = !enabled;
  c10::impl::tls_set_dispatch_key_excluded(
      get_autocast_dispatch_key_from_device_type(device_type), exclude_key);
}
|
||||
|
||||
namespace {
|
||||
@ -91,30 +45,40 @@ std::mutex cached_casts_mutex;
|
||||
// it calls clear_cache() to ensure cached Tensors don't leak outside the autocasting region.
|
||||
thread_local int nesting = 0;
|
||||
|
||||
// autocast_cpu_dtype is the lower_precision_fp used by AutocastCPU.
|
||||
thread_local at::ScalarType autocast_cpu_dtype = at::kBFloat16;
|
||||
|
||||
// autocast_xpu_dtype is the lower_precision_fp used by AutocastXPU.
|
||||
thread_local at::ScalarType autocast_xpu_dtype = at::kBFloat16;
|
||||
|
||||
// autocast_ipu_dtype is the lower_precision_fp used by AutocastIPU.
|
||||
thread_local at::ScalarType autocast_ipu_dtype = at::kHalf;
|
||||
|
||||
// autocast_hpu_dtype is the lower_precision_fp used by AutocastHPU.
|
||||
thread_local at::ScalarType autocast_hpu_dtype = at::kBFloat16;
|
||||
|
||||
// autocast_xla_dtype is the lower_precision_fp used by AutocastXLA.
|
||||
thread_local at::ScalarType autocast_xla_dtype = at::kBFloat16;
|
||||
// The order of this array MUST exactly match the definition order of DeviceType
|
||||
// in c10/core/DeviceType.h.
|
||||
static_assert(
|
||||
at::COMPILE_TIME_MAX_DEVICE_TYPES == 21,
|
||||
"The definition of the default autocast data type per device backend doesn't match with the definition of the device type.");
|
||||
thread_local std::array<at::ScalarType, at::COMPILE_TIME_MAX_DEVICE_TYPES>
|
||||
autocast_dtype = {
|
||||
at::kBFloat16, // CPU
|
||||
at::kHalf, // CUDA.
|
||||
at::ScalarType::Undefined, // Reserved for explicit MKLDNN
|
||||
at::ScalarType::Undefined, // OpenGL
|
||||
at::ScalarType::Undefined, // OpenCL
|
||||
at::ScalarType::Undefined, // IDEEP.
|
||||
at::kHalf, // AMD HIP
|
||||
at::ScalarType::Undefined, // FPGA
|
||||
at::ScalarType::Undefined, // ONNX Runtime / Microsoft
|
||||
at::kBFloat16, // XLA / TPU
|
||||
at::ScalarType::Undefined, // Vulkan
|
||||
at::ScalarType::Undefined, // Metal
|
||||
at::kBFloat16, // XPU
|
||||
at::ScalarType::Undefined, // MPS
|
||||
at::ScalarType::Undefined, // Meta (tensors with no data)
|
||||
at::kBFloat16, // HPU / HABANA
|
||||
at::ScalarType::Undefined, // SX-Aurora / NEC
|
||||
at::ScalarType::Undefined, // Lazy Tensors
|
||||
at::kHalf, // Graphcore IPU
|
||||
at::ScalarType::Undefined, // Meta training and inference devices
|
||||
at::kHalf, // PrivateUse1 device
|
||||
};
|
||||
|
||||
// Whether the cache is enabled inside autocast.
|
||||
thread_local bool cache_enabled = true;
|
||||
|
||||
// autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
|
||||
thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;
|
||||
|
||||
// autocast_privateuseone_dtype is the lower_precision_fp used by AutocastPrivateUse1.
|
||||
thread_local at::ScalarType autocast_privateuseone_dtype = at::kHalf;
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
||||
void clear_cache() {
|
||||
const std::lock_guard<std::mutex> lock(cached_casts_mutex);
|
||||
@ -129,60 +93,12 @@ int decrement_nesting() {
|
||||
return --nesting;
|
||||
}
|
||||
|
||||
at::ScalarType get_autocast_gpu_dtype() {
|
||||
return autocast_gpu_dtype;
|
||||
// Returns the autocast dtype currently stored for `device_type`, read from the
// autocast_dtype table slot at the device type's integer value.
at::ScalarType get_autocast_dtype(at::DeviceType device_type) {
  const int slot = static_cast<int>(device_type);
  return autocast_dtype[slot];
}
|
||||
|
||||
at::ScalarType get_autocast_cpu_dtype() {
|
||||
return autocast_cpu_dtype;
|
||||
}
|
||||
|
||||
at::ScalarType get_autocast_xpu_dtype() {
|
||||
return autocast_xpu_dtype;
|
||||
}
|
||||
|
||||
at::ScalarType get_autocast_ipu_dtype() {
|
||||
return autocast_ipu_dtype;
|
||||
}
|
||||
|
||||
at::ScalarType get_autocast_hpu_dtype() {
|
||||
return autocast_hpu_dtype;
|
||||
}
|
||||
|
||||
at::ScalarType get_autocast_xla_dtype() {
|
||||
return autocast_xla_dtype;
|
||||
}
|
||||
|
||||
at::ScalarType get_autocast_privateuseone_dtype() {
|
||||
return autocast_privateuseone_dtype;
|
||||
}
|
||||
|
||||
void set_autocast_cpu_dtype(at::ScalarType dtype) {
|
||||
autocast_cpu_dtype = dtype;
|
||||
}
|
||||
|
||||
void set_autocast_gpu_dtype(at::ScalarType dtype) {
|
||||
autocast_gpu_dtype = dtype;
|
||||
}
|
||||
|
||||
void set_autocast_xpu_dtype(at::ScalarType dtype) {
|
||||
autocast_xpu_dtype = dtype;
|
||||
}
|
||||
|
||||
void set_autocast_ipu_dtype(at::ScalarType dtype) {
|
||||
autocast_ipu_dtype = dtype;
|
||||
}
|
||||
|
||||
void set_autocast_hpu_dtype(at::ScalarType dtype) {
|
||||
autocast_hpu_dtype = dtype;
|
||||
}
|
||||
|
||||
void set_autocast_xla_dtype(at::ScalarType dtype) {
|
||||
autocast_xla_dtype = dtype;
|
||||
}
|
||||
|
||||
void set_autocast_privateuseone_dtype(at::ScalarType dtype) {
|
||||
autocast_privateuseone_dtype = dtype;
|
||||
// Stores `dtype` as the autocast dtype for `device_type`, writing the
// autocast_dtype table slot at the device type's integer value.
void set_autocast_dtype(at::DeviceType device_type, at::ScalarType dtype) {
  const int slot = static_cast<int>(device_type);
  autocast_dtype[slot] = dtype;
}
|
||||
|
||||
bool is_autocast_cache_enabled() {
|
||||
|
@ -10,40 +10,120 @@
|
||||
|
||||
namespace at::autocast {
|
||||
|
||||
TORCH_API bool is_enabled();
|
||||
TORCH_API void set_enabled(bool enabled);
|
||||
TORCH_API bool is_autocast_enabled(at::DeviceType device_type);
|
||||
TORCH_API void set_autocast_enabled(at::DeviceType device_type, bool enabled);
|
||||
TORCH_API at::ScalarType get_autocast_dtype(at::DeviceType device_type);
|
||||
TORCH_API void set_autocast_dtype(
|
||||
at::DeviceType device_type,
|
||||
at::ScalarType dtype);
|
||||
TORCH_API void clear_cache();
|
||||
TORCH_API int increment_nesting();
|
||||
TORCH_API int decrement_nesting();
|
||||
TORCH_API bool is_cpu_enabled();
|
||||
TORCH_API void set_cpu_enabled(bool enabled);
|
||||
TORCH_API at::ScalarType get_autocast_gpu_dtype();
|
||||
TORCH_API at::ScalarType get_autocast_cpu_dtype();
|
||||
TORCH_API void set_autocast_gpu_dtype(at::ScalarType dtype);
|
||||
TORCH_API void set_autocast_cpu_dtype(at::ScalarType dtype);
|
||||
TORCH_API bool is_xpu_enabled();
|
||||
TORCH_API void set_xpu_enabled(bool enabled);
|
||||
TORCH_API at::ScalarType get_autocast_xpu_dtype();
|
||||
TORCH_API void set_autocast_xpu_dtype(at::ScalarType dtype);
|
||||
TORCH_API bool is_ipu_enabled();
|
||||
TORCH_API void set_ipu_enabled(bool enabled);
|
||||
TORCH_API at::ScalarType get_autocast_ipu_dtype();
|
||||
TORCH_API void set_autocast_ipu_dtype(at::ScalarType dtype);
|
||||
TORCH_API bool is_hpu_enabled();
|
||||
TORCH_API void set_hpu_enabled(bool enabled);
|
||||
TORCH_API at::ScalarType get_autocast_hpu_dtype();
|
||||
TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
|
||||
TORCH_API bool is_xla_enabled();
|
||||
TORCH_API void set_xla_enabled(bool enabled);
|
||||
TORCH_API at::ScalarType get_autocast_xla_dtype();
|
||||
TORCH_API void set_autocast_xla_dtype(at::ScalarType dtype);
|
||||
TORCH_API bool is_privateuseone_enabled();
|
||||
TORCH_API void set_privateuseone_enabled(bool enabled);
|
||||
TORCH_API at::ScalarType get_autocast_privateuseone_dtype();
|
||||
TORCH_API void set_autocast_privateuseone_dtype(at::ScalarType dtype);
|
||||
TORCH_API bool is_autocast_cache_enabled();
|
||||
TORCH_API void set_autocast_cache_enabled(bool enabled);
|
||||
|
||||
// deprecated CUDA-specific autocast APIs
|
||||
// Deprecated CUDA-specific wrapper. Tagged with C10_DEPRECATED_MESSAGE and
// additionally issues TORCH_WARN_DEPRECATION when called, then forwards to
// is_autocast_enabled(at::kCUDA) and returns its result.
C10_DEPRECATED_MESSAGE(
|
||||
"at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
|
||||
TORCH_API inline bool is_enabled() {
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"at::autocast::",
|
||||
__func__,
|
||||
"() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.")
|
||||
return is_autocast_enabled(at::kCUDA);
|
||||
}
|
||||
// Deprecated CUDA-specific wrapper. Tagged with C10_DEPRECATED_MESSAGE and
// additionally issues TORCH_WARN_DEPRECATION when called, then forwards
// `enabled` to set_autocast_enabled(at::kCUDA, enabled).
C10_DEPRECATED_MESSAGE(
|
||||
"at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
|
||||
TORCH_API inline void set_enabled(bool enabled) {
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"at::autocast::",
|
||||
__func__,
|
||||
"(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.")
|
||||
set_autocast_enabled(at::kCUDA, enabled);
|
||||
}
|
||||
// Deprecated CUDA-specific wrapper. Tagged with C10_DEPRECATED_MESSAGE and
// additionally issues TORCH_WARN_DEPRECATION when called, then forwards to
// get_autocast_dtype(at::kCUDA) and returns its result.
C10_DEPRECATED_MESSAGE(
|
||||
"at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
|
||||
TORCH_API inline at::ScalarType get_autocast_gpu_dtype() {
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"at::autocast::",
|
||||
__func__,
|
||||
"() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.")
|
||||
return get_autocast_dtype(at::kCUDA);
|
||||
}
|
||||
// Deprecated CUDA-specific wrapper. Tagged with C10_DEPRECATED_MESSAGE and
// additionally issues TORCH_WARN_DEPRECATION when called, then forwards
// `dtype` to set_autocast_dtype(at::kCUDA, dtype).
C10_DEPRECATED_MESSAGE(
|
||||
"at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
|
||||
TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) {
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"at::autocast::",
|
||||
__func__,
|
||||
"(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.")
|
||||
set_autocast_dtype(at::kCUDA, dtype);
|
||||
}
|
||||
|
||||
// Generates the four deprecated, backend-specific autocast wrappers for one
// backend:
//   bool is_<name>_enabled()
//   void set_<name>_enabled(bool enabled)
//   at::ScalarType get_autocast_<name>_dtype()
//   void set_autocast_<name>_dtype(at::ScalarType dtype)
// `name` is pasted into the function names and messages; `device_type` is the
// expression passed to the device-agnostic API and is also stringized into the
// messages. Each wrapper is tagged with C10_DEPRECATED_MESSAGE, issues
// TORCH_WARN_DEPRECATION when called, and then forwards to the generic
// is_autocast_enabled / set_autocast_enabled / get_autocast_dtype /
// set_autocast_dtype function. The runtime warning text of every wrapper
// matches its C10_DEPRECATED_MESSAGE wording ("Please use ... instead.").
#define DECLARE_DEPRECATED_AUTOCAST_APIS(name, device_type) \
  C10_DEPRECATED_MESSAGE( \
      "at::autocast::is_" #name \
      "_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
      ") instead.") \
  TORCH_API inline bool is_##name##_enabled() { \
    TORCH_WARN_DEPRECATION( \
        "at::autocast::", \
        __func__, \
        "() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \
        ") instead.") \
    return is_autocast_enabled(device_type); \
  } \
  \
  C10_DEPRECATED_MESSAGE( \
      "at::autocast::set_" #name \
      "_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
      ", enabled) instead.") \
  TORCH_API inline void set_##name##_enabled(bool enabled) { \
    TORCH_WARN_DEPRECATION( \
        "at::autocast::", \
        __func__, \
        "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \
        ", enabled) instead.") \
    set_autocast_enabled(device_type, enabled); \
  } \
  \
  C10_DEPRECATED_MESSAGE( \
      "at::autocast::get_autocast_" #name \
      "_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
      ") instead.") \
  TORCH_API inline at::ScalarType get_autocast_##name##_dtype() { \
    TORCH_WARN_DEPRECATION( \
        "at::autocast::", \
        __func__, \
        "() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \
        ") instead.") \
    return get_autocast_dtype(device_type); \
  } \
  \
  C10_DEPRECATED_MESSAGE( \
      "at::autocast::set_autocast_" #name \
      "_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
      ", dtype) instead.") \
  TORCH_API inline void set_autocast_##name##_dtype(at::ScalarType dtype) { \
    TORCH_WARN_DEPRECATION( \
        "at::autocast::", \
        __func__, \
        "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \
        ", dtype) instead.") \
    set_autocast_dtype(device_type, dtype); \
  }
|
||||
|
||||
// X-macro over the backends that get deprecated backend-specific wrappers:
// applies `_` to each (function-name fragment, device type constant) pair
// listed below (cpu, xpu, xla, hpu, ipu, privateuseone).
// NOTE(review): the identifier misspells "BACKENDS"; it is left unchanged here
// because renaming it would require updating every user of the macro.
#define AT_FORALL_DEPRECATED_AUTOCAST_BAKCNEDS(_) \
|
||||
_(cpu, at::kCPU) \
|
||||
_(xpu, at::kXPU) \
|
||||
_(xla, at::kXLA) \
|
||||
_(hpu, at::kHPU) \
|
||||
_(ipu, at::kIPU) \
|
||||
_(privateuseone, at::kPrivateUse1)
|
||||
|
||||
// Deprecated autocast APIs specific to the other (non-CUDA) backends.
|
||||
AT_FORALL_DEPRECATED_AUTOCAST_BAKCNEDS(DECLARE_DEPRECATED_AUTOCAST_APIS)
|
||||
|
||||
namespace {
|
||||
inline bool is_autocast_eligible(
|
||||
const Tensor& tensor,
|
||||
@ -96,22 +176,12 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
|
||||
|
||||
inline at::ScalarType get_lower_precision_fp_from_device_type(
|
||||
c10::DeviceType device_type) {
|
||||
switch (device_type) {
|
||||
case c10::DeviceType::CUDA:
|
||||
return get_autocast_gpu_dtype();
|
||||
case c10::DeviceType::CPU:
|
||||
return get_autocast_cpu_dtype();
|
||||
case c10::DeviceType::XPU:
|
||||
return get_autocast_xpu_dtype();
|
||||
case c10::DeviceType::IPU:
|
||||
return get_autocast_ipu_dtype();
|
||||
case c10::DeviceType::HPU:
|
||||
return get_autocast_hpu_dtype();
|
||||
case c10::DeviceType::XLA:
|
||||
return get_autocast_xla_dtype();
|
||||
case c10::DeviceType::PrivateUse1:
|
||||
return get_autocast_privateuseone_dtype();
|
||||
default:
|
||||
if (device_type == at::kCPU || device_type == at::kCUDA ||
|
||||
device_type == at::kXPU || device_type == at::kIPU ||
|
||||
device_type == at::kHPU || device_type == at::kXLA ||
|
||||
device_type == at::kPrivateUse1) {
|
||||
return get_autocast_dtype(device_type);
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"unknown device type for autocast in get_lower_precision_fp_from_device_type");
|
||||
}
|
||||
|
@ -139,6 +139,7 @@ class TestPublicBindings(TestCase):
|
||||
"Generator",
|
||||
"GeneratorType",
|
||||
"get_autocast_cpu_dtype",
|
||||
"get_autocast_dtype",
|
||||
"get_autocast_ipu_dtype",
|
||||
"get_default_dtype",
|
||||
"get_num_interop_threads",
|
||||
@ -216,6 +217,7 @@ class TestPublicBindings(TestCase):
|
||||
"set_anomaly_enabled",
|
||||
"set_autocast_cache_enabled",
|
||||
"set_autocast_cpu_dtype",
|
||||
"set_autocast_dtype",
|
||||
"set_autocast_ipu_dtype",
|
||||
"set_autocast_cpu_enabled",
|
||||
"set_autocast_ipu_enabled",
|
||||
|
@ -1249,8 +1249,16 @@ def is_grad_enabled() -> _bool: ...
|
||||
def _set_fwd_grad_enabled(enabled: _bool) -> None: ...
|
||||
def _is_fwd_grad_enabled() -> _bool: ...
|
||||
def is_inference_mode_enabled() -> _bool: ...
|
||||
@overload
|
||||
def set_autocast_enabled(device_type: str, enabled: _bool) -> None: ...
|
||||
@overload
|
||||
def set_autocast_enabled(enabled: _bool) -> None: ...
|
||||
@overload
|
||||
def is_autocast_enabled(device_type: str) -> _bool: ...
|
||||
@overload
|
||||
def is_autocast_enabled() -> _bool: ...
|
||||
def set_autocast_dtype(device_type: str, dtype: _dtype) -> None: ...
|
||||
def get_autocast_dtype(device_type: str) -> _dtype: ...
|
||||
def clear_autocast_cache() -> None: ...
|
||||
def set_autocast_cpu_enabled(enabled: _bool) -> None: ...
|
||||
def is_autocast_cpu_enabled() -> _bool: ...
|
||||
|
@ -595,21 +595,30 @@ class OutputGraph:
|
||||
self.torch_function_enabled,
|
||||
)
|
||||
global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled())
|
||||
|
||||
def autocast_specific_backend(
|
||||
device_type: str, func: Callable[[str, Any], None]
|
||||
):
|
||||
def decorator(value):
|
||||
return func(device_type, value)
|
||||
|
||||
return decorator
|
||||
|
||||
global_state["autocast_enabled"] = (
|
||||
torch.set_autocast_enabled,
|
||||
torch.is_autocast_enabled(),
|
||||
autocast_specific_backend("cuda", torch.set_autocast_enabled),
|
||||
torch.is_autocast_enabled("cuda"),
|
||||
)
|
||||
global_state["autocast_cpu_enabled"] = (
|
||||
torch.set_autocast_cpu_enabled,
|
||||
torch.is_autocast_cpu_enabled(),
|
||||
autocast_specific_backend("cpu", torch.set_autocast_enabled),
|
||||
torch.is_autocast_enabled("cpu"),
|
||||
)
|
||||
global_state["autocast_gpu_dtype"] = (
|
||||
torch.set_autocast_gpu_dtype,
|
||||
torch.get_autocast_gpu_dtype(),
|
||||
autocast_specific_backend("cuda", torch.set_autocast_dtype),
|
||||
torch.get_autocast_dtype("cuda"),
|
||||
)
|
||||
global_state["autocast_cpu_dtype"] = (
|
||||
torch.set_autocast_cpu_dtype,
|
||||
torch.get_autocast_cpu_dtype(),
|
||||
autocast_specific_backend("cpu", torch.set_autocast_dtype),
|
||||
torch.get_autocast_dtype("cpu"),
|
||||
)
|
||||
global_state["autocast_cache_enabled"] = (
|
||||
torch.set_autocast_cache_enabled,
|
||||
|
@ -79,10 +79,10 @@ def normalize_as_list(x):
|
||||
|
||||
def _get_autocast_states():
|
||||
return [
|
||||
torch.is_autocast_enabled(),
|
||||
torch.is_autocast_cpu_enabled(),
|
||||
torch.get_autocast_gpu_dtype(),
|
||||
torch.get_autocast_cpu_dtype(),
|
||||
torch.is_autocast_enabled("cuda"),
|
||||
torch.is_autocast_enabled("cpu"),
|
||||
torch.get_autocast_dtype("cuda"),
|
||||
torch.get_autocast_dtype("cpu"),
|
||||
torch.is_autocast_cache_enabled(),
|
||||
]
|
||||
|
||||
|
@ -474,24 +474,47 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
||||
Py_RETURN_TRUE;
|
||||
}
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
namespace torch::autograd {
|
||||
|
||||
static PyObject* set_autocast_enabled(PyObject* _unused, PyObject* arg) {
|
||||
static PyObject* set_autocast_enabled(
|
||||
PyObject* _unused,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK_TYPE(
|
||||
PyBool_Check(arg),
|
||||
"enabled must be a bool (got ",
|
||||
Py_TYPE(arg)->tp_name,
|
||||
")");
|
||||
at::autocast::set_enabled(arg == Py_True);
|
||||
static PythonArgParser parser(
|
||||
{"set_autocast_enabled(c10::string_view device_type, bool enabled)",
|
||||
"set_autocast_enabled(bool enabled)"}); // this signature is depracated.
|
||||
ParsedArgs<2> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
// Set at::kCUDA as default value to prevent BC-breaking changes.
|
||||
at::DeviceType device_type = at::kCUDA;
|
||||
int enabled_id = 0;
|
||||
if (r.idx == 0) {
|
||||
device_type = at::Device(r.string(0)).type();
|
||||
enabled_id = 1;
|
||||
}
|
||||
auto enabled = r.toBool(enabled_id);
|
||||
at::autocast::set_autocast_enabled(device_type, enabled);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) {
|
||||
static PyObject* is_autocast_enabled(
|
||||
PyObject* _unused,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (at::autocast::is_enabled()) {
|
||||
static PythonArgParser parser(
|
||||
{"is_autocast_enabled(c10::string_view device_type)",
|
||||
"is_autocast_enabled()"}); // this signature is depracated.
|
||||
ParsedArgs<1> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
// Set at::kCUDA as default value to prevent BC-breaking changes.
|
||||
at::DeviceType device_type = at::kCUDA;
|
||||
if (r.idx == 0) {
|
||||
device_type = at::Device(r.string(0)).type();
|
||||
}
|
||||
if (at::autocast::is_autocast_enabled(device_type)) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
@ -499,11 +522,48 @@ static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) {
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// Python binding `get_autocast_dtype(device_type)`.
// Parses the single `device_type` string argument, converts it to a device
// type via at::Device, queries at::autocast::get_autocast_dtype for it, and
// returns the corresponding THPDtype object with its reference count
// incremented.
static PyObject* get_autocast_dtype(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"get_autocast_dtype(c10::string_view device_type)"});
  ParsedArgs<1> parsed_args;
  auto parsed = parser.parse(args, kwargs, parsed_args);
  const at::DeviceType target_device = at::Device(parsed.string(0)).type();
  const at::ScalarType autocast_scalar_type =
      at::autocast::get_autocast_dtype(target_device);
  auto* dtype_object =
      reinterpret_cast<PyObject*>(torch::getTHPDtype(autocast_scalar_type));
  // The caller receives its own reference to the dtype object.
  Py_INCREF(dtype_object);
  return dtype_object;
  END_HANDLE_TH_ERRORS
}
|
||||
|
||||
// Python binding `set_autocast_dtype(device_type, dtype)`.
// Parses the `device_type` string and the `dtype` argument, converts the
// string to a device type via at::Device, and stores the dtype for that device
// type through at::autocast::set_autocast_dtype. Returns None.
static PyObject* set_autocast_dtype(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"set_autocast_dtype(c10::string_view device_type, ScalarType dtype)"});
  ParsedArgs<2> parsed_args;
  auto parsed = parser.parse(args, kwargs, parsed_args);
  // Resolve the device first, then the dtype, in the same order as the
  // argument list.
  const at::DeviceType target_device = at::Device(parsed.string(0)).type();
  const at::ScalarType requested_dtype = parsed.scalartype(1);
  at::autocast::set_autocast_dtype(target_device, requested_dtype);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
||||
|
||||
static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (at::autocast::is_enabled() || at::autocast::is_cpu_enabled() ||
|
||||
at::autocast::is_xpu_enabled() || at::autocast::is_ipu_enabled() ||
|
||||
at::autocast::is_xla_enabled() || at::autocast::is_hpu_enabled()) {
|
||||
if (at::autocast::is_autocast_enabled(at::kCPU) ||
|
||||
at::autocast::is_autocast_enabled(at::kCUDA) ||
|
||||
at::autocast::is_autocast_enabled(at::kXPU) ||
|
||||
at::autocast::is_autocast_enabled(at::kIPU) ||
|
||||
at::autocast::is_autocast_enabled(at::kXLA) ||
|
||||
at::autocast::is_autocast_enabled(at::kHPU) ||
|
||||
at::autocast::is_autocast_enabled(at::kPrivateUse1)) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
@ -518,14 +578,18 @@ static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
|
||||
"enabled must be a bool (got ",
|
||||
Py_TYPE(arg)->tp_name,
|
||||
")");
|
||||
at::autocast::set_cpu_enabled(arg == Py_True);
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead.")
|
||||
at::autocast::set_autocast_enabled(at::kCPU, arg == Py_True);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* is_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (at::autocast::is_cpu_enabled()) {
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead.")
|
||||
if (at::autocast::is_autocast_enabled(at::kCPU)) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
@ -540,14 +604,18 @@ static PyObject* set_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) {
|
||||
"enabled must be a bool (got ",
|
||||
Py_TYPE(arg)->tp_name,
|
||||
")");
|
||||
at::autocast::set_ipu_enabled(arg == Py_True);
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.set_autocast_ipu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('ipu', enabled) instead.")
|
||||
at::autocast::set_autocast_enabled(at::kIPU, arg == Py_True);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* is_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (at::autocast::is_ipu_enabled()) {
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.is_autocast_ipu_enabled() is deprecated. Please use torch.is_autocast_enabled('ipu') instead.")
|
||||
if (at::autocast::is_autocast_enabled(at::kIPU)) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
@ -562,14 +630,18 @@ static PyObject* set_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
|
||||
"enabled must be a bool (got ",
|
||||
Py_TYPE(arg)->tp_name,
|
||||
")");
|
||||
at::autocast::set_xla_enabled(arg == Py_True);
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.set_autocast_xla_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('xla', enabled) instead.")
|
||||
at::autocast::set_autocast_enabled(at::kXLA, arg == Py_True);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* is_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
if (at::autocast::is_xla_enabled()) {
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.is_autocast_xla_enabled() is deprecated. Please use torch.is_autocast_enabled('xla') instead.")
|
||||
if (at::autocast::is_autocast_enabled(at::kXLA)) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
@ -584,8 +656,10 @@ static PyObject* set_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
|
||||
"dtype must be a torch.dtype (got ",
|
||||
Py_TYPE(arg)->tp_name,
|
||||
")");
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.set_autocast_gpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cuda', dtype) instead.")
|
||||
at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
|
||||
at::autocast::set_autocast_gpu_dtype(targetType);
|
||||
at::autocast::set_autocast_dtype(at::kCUDA, targetType);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
@ -597,8 +671,10 @@ static PyObject* set_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
|
||||
"dtype must be a torch.dtype (got ",
|
||||
Py_TYPE(arg)->tp_name,
|
||||
")");
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead.")
|
||||
at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
|
||||
at::autocast::set_autocast_cpu_dtype(targetType);
|
||||
at::autocast::set_autocast_dtype(at::kCPU, targetType);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
@ -610,8 +686,10 @@ static PyObject* set_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
|
||||
"dtype must be a torch.dtype (got ",
|
||||
Py_TYPE(arg)->tp_name,
|
||||
")");
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.set_autocast_ipu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('ipu', dtype) instead.")
|
||||
at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
|
||||
at::autocast::set_autocast_ipu_dtype(targetType);
|
||||
at::autocast::set_autocast_dtype(at::kIPU, targetType);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
@ -623,15 +701,19 @@ static PyObject* set_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
|
||||
"dtype must be a torch.dtype (got ",
|
||||
Py_TYPE(arg)->tp_name,
|
||||
")");
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.set_autocast_xla_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('xla', dtype) instead.")
|
||||
at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
|
||||
at::autocast::set_autocast_xla_dtype(targetType);
|
||||
at::autocast::set_autocast_dtype(at::kXLA, targetType);
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype();
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead.")
|
||||
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCUDA);
|
||||
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
|
||||
Py_INCREF(dtype);
|
||||
return dtype;
|
||||
@ -640,7 +722,9 @@ static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
|
||||
|
||||
static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
at::ScalarType current_dtype = at::autocast::get_autocast_cpu_dtype();
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead.")
|
||||
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCPU);
|
||||
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
|
||||
Py_INCREF(dtype);
|
||||
return dtype;
|
||||
@ -649,7 +733,9 @@ static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
|
||||
|
||||
static PyObject* get_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
at::ScalarType current_dtype = at::autocast::get_autocast_ipu_dtype();
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.get_autocast_ipu_dtype() is deprecated. Please use torch.get_autocast_dtype('ipu') instead.")
|
||||
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kIPU);
|
||||
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
|
||||
Py_INCREF(dtype);
|
||||
return dtype;
|
||||
@ -658,7 +744,9 @@ static PyObject* get_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
|
||||
|
||||
static PyObject* get_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
|
||||
HANDLE_TH_ERRORS
|
||||
at::ScalarType current_dtype = at::autocast::get_autocast_xla_dtype();
|
||||
TORCH_WARN_DEPRECATION(
|
||||
"torch.get_autocast_xla_dtype() is deprecated. Please use torch.get_autocast_dtype('xla') instead.")
|
||||
at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kXLA);
|
||||
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
|
||||
Py_INCREF(dtype);
|
||||
return dtype;
|
||||
@ -1123,8 +1211,22 @@ static PyMethodDef methods[] = { // NOLINT
|
||||
is_inference_mode_enabled,
|
||||
METH_NOARGS,
|
||||
nullptr},
|
||||
{"set_autocast_enabled", set_autocast_enabled, METH_O, nullptr},
|
||||
{"is_autocast_enabled", is_autocast_enabled, METH_NOARGS, nullptr},
|
||||
{"set_autocast_enabled",
|
||||
castPyCFunctionWithKeywords(set_autocast_enabled),
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
nullptr},
|
||||
{"is_autocast_enabled",
|
||||
castPyCFunctionWithKeywords(is_autocast_enabled),
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
nullptr},
|
||||
{"set_autocast_dtype",
|
||||
castPyCFunctionWithKeywords(set_autocast_dtype),
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
nullptr},
|
||||
{"get_autocast_dtype",
|
||||
castPyCFunctionWithKeywords(get_autocast_dtype),
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
nullptr},
|
||||
{"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
|
||||
{"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
|
||||
{"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
|
||||
@ -1225,5 +1327,4 @@ PyMethodDef* python_functions() {
|
||||
return methods;
|
||||
}
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
} // namespace torch::autograd
|
||||
|
@ -116,8 +116,8 @@ GraphFunction::SpecializationKey GraphFunction::currentSpecialization() const {
|
||||
// disabling autodiff pass for mobile build since autocast APIs don't exist
|
||||
return SpecializationKey::AutocastOff;
|
||||
#else
|
||||
bool cpu_enabled = at::autocast::is_cpu_enabled();
|
||||
bool gpu_enabled = at::autocast::is_enabled();
|
||||
bool cpu_enabled = at::autocast::is_autocast_enabled(at::kCPU);
|
||||
bool gpu_enabled = at::autocast::is_autocast_enabled(at::kCUDA);
|
||||
if (cpu_enabled && gpu_enabled) {
|
||||
return SpecializationKey::CpuGpuAutocastOn;
|
||||
} else if (!cpu_enabled && !gpu_enabled) {
|
||||
|
@ -521,10 +521,10 @@ void Autocast(const std::shared_ptr<Graph>& graph) {
|
||||
GRAPH_DUMP("\nBefore Autocast: ", graph);
|
||||
if (autocastEnabled()) {
|
||||
AutocastContext init = {
|
||||
at::autocast::is_enabled(),
|
||||
at::autocast::is_cpu_enabled(),
|
||||
at::autocast::get_autocast_gpu_dtype(),
|
||||
at::autocast::get_autocast_cpu_dtype()};
|
||||
at::autocast::is_autocast_enabled(at::kCUDA),
|
||||
at::autocast::is_autocast_enabled(at::kCPU),
|
||||
at::autocast::get_autocast_dtype(at::kCUDA),
|
||||
at::autocast::get_autocast_dtype(at::kCPU)};
|
||||
handleBlock(graph->block(), init);
|
||||
}
|
||||
GRAPH_DUMP("\nAfter Autocast: ", graph);
|
||||
|
@ -799,7 +799,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
|
||||
bool enabled = false;
|
||||
#else
|
||||
bool enabled = at::autocast::is_enabled();
|
||||
bool enabled = at::autocast::is_autocast_enabled(at::kCUDA);
|
||||
#endif
|
||||
push(stack, enabled);
|
||||
},
|
||||
@ -810,7 +810,7 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs{
|
||||
#if defined BUILD_LITE_INTERPRETER || defined C10_MOBILE
|
||||
bool enabled = false;
|
||||
#else
|
||||
bool enabled = at::autocast::is_cpu_enabled();
|
||||
bool enabled = at::autocast::is_autocast_enabled(at::kCPU);
|
||||
#endif
|
||||
push(stack, enabled);
|
||||
},
|
||||
|
@ -256,6 +256,8 @@ def get_ignored_functions() -> Set[Callable]:
|
||||
handle_torch_function,
|
||||
torch.set_autocast_enabled,
|
||||
torch.is_autocast_enabled,
|
||||
torch.set_autocast_dtype,
|
||||
torch.get_autocast_dtype,
|
||||
torch.clear_autocast_cache,
|
||||
torch.set_autocast_cpu_enabled,
|
||||
torch.is_autocast_cpu_enabled,
|
||||
|
@ -194,25 +194,18 @@ def set_device_states(devices, states) -> None:
|
||||
|
||||
|
||||
def _get_autocast_kwargs(device="cuda"):
|
||||
if device == "cuda":
|
||||
if _supports_autocast(device):
|
||||
device_autocast_kwargs = {
|
||||
"enabled": torch.is_autocast_enabled(),
|
||||
"dtype": torch.get_autocast_gpu_dtype(),
|
||||
"cache_enabled": torch.is_autocast_cache_enabled(),
|
||||
}
|
||||
elif _supports_autocast(device):
|
||||
device_module = _get_device_module(device)
|
||||
device_autocast_kwargs = {
|
||||
"enabled": device_module.is_autocast_enabled(),
|
||||
"dtype": device_module.get_autocast_dtype(),
|
||||
"enabled": torch.is_autocast_enabled(device),
|
||||
"dtype": torch.get_autocast_dtype(device),
|
||||
"cache_enabled": torch.is_autocast_cache_enabled(),
|
||||
}
|
||||
else:
|
||||
device_autocast_kwargs = None
|
||||
|
||||
cpu_autocast_kwargs = {
|
||||
"enabled": torch.is_autocast_cpu_enabled(),
|
||||
"dtype": torch.get_autocast_cpu_dtype(),
|
||||
"enabled": torch.is_autocast_enabled('cpu'),
|
||||
"dtype": torch.get_autocast_dtype('cpu'),
|
||||
"cache_enabled": torch.is_autocast_cache_enabled(),
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user