[AMP] Support XLA:TPU (#96370)
This lands alongside https://github.com/pytorch/xla/pull/5148 and https://github.com/pytorch/xla/pull/4740. With these changes, XLA:GPU users should use `torch.cuda.amp.autocast()` for AMP with float16, while XLA:TPU users should use `torch.amp.autocast('xla')` for AMP with bfloat16.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96370
Approved by: https://github.com/bdhirsh, https://github.com/malfet
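A minimal usage sketch of the split described above (hedged: assumes `torch_xla` is installed and an XLA device is attached; `xm.xla_device()` comes from `torch_xla`, not from this PR):

```python
import torch
import torch_xla.core.xla_model as xm  # provided by torch_xla, not this PR

device = xm.xla_device()
a = torch.randn(8, 8, device=device)
b = torch.randn(8, 8, device=device)

# XLA:TPU path added by this PR: bfloat16 autocast under device type 'xla'.
with torch.amp.autocast('xla', dtype=torch.bfloat16):
    c = a @ b  # matmul inputs are cast to bfloat16
assert c.dtype == torch.bfloat16

# XLA:GPU keeps the existing CUDA autocast path with float16:
#   with torch.cuda.amp.autocast(dtype=torch.float16): ...
```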
.github/ci_commit_pins/xla.txt (vendored, 2 lines changed)

@@ -1 +1 @@
-0eeaefb2341c3beea65545f278cdbd998f5a8399
+73392fc2a6c9ec40cba968ea66754514346ac79f
@@ -49,6 +49,14 @@ void set_hpu_enabled(bool new_enabled) {
   c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastHPU, !new_enabled);
 }
 
+bool is_xla_enabled() {
+  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastXLA);
+}
+
+void set_xla_enabled(bool new_enabled) {
+  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastXLA, !new_enabled);
+}
+
 bool is_privateuseone_enabled() {
   return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1);
 }
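The enable/disable pair above is a thread-local toggle on the dispatcher's excluded-key set: "enabled" means `DispatchKey::AutocastXLA` is *not* excluded. A hedged sketch of the same toggle through the Python bindings this commit adds further down:

```python
import torch

prev = torch.is_autocast_xla_enabled()  # True iff AutocastXLA is not TLS-excluded
torch.set_autocast_xla_enabled(True)    # removes AutocastXLA from the excluded set
assert torch.is_autocast_xla_enabled()
torch.set_autocast_xla_enabled(prev)    # restore, mirroring autocast.__exit__
```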
@@ -98,6 +106,9 @@ thread_local at::ScalarType autocast_ipu_dtype = at::kHalf;
 // autocast_hpu_dtype is the lower_precision_fp used by AutocastHPU.
 thread_local at::ScalarType autocast_hpu_dtype = at::kBFloat16;
 
+// autocast_xla_dtype is the lower_precision_fp used by AutocastXLA.
+thread_local at::ScalarType autocast_xla_dtype = at::kBFloat16;
+
 // should we enabled the cache inside autocast.
 thread_local bool cache_enabled = true;
@@ -141,6 +152,10 @@ at::ScalarType get_autocast_hpu_dtype() {
   return autocast_hpu_dtype;
 }
 
+at::ScalarType get_autocast_xla_dtype() {
+  return autocast_xla_dtype;
+}
+
 at::ScalarType get_autocast_privateuseone_dtype() {
   return autocast_privateuseone_dtype;
 }
@@ -168,6 +183,10 @@ void set_autocast_hpu_dtype(at::ScalarType dtype) {
   autocast_hpu_dtype = dtype;
 }
 
+void set_autocast_xla_dtype(at::ScalarType dtype) {
+  autocast_xla_dtype = dtype;
+}
+
 void set_autocast_privateuseone_dtype(at::ScalarType dtype) {
   autocast_privateuseone_dtype = dtype;
 }
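The getters and setters above manage a per-thread dtype that defaults to `at::kBFloat16`. A hedged roundtrip through the corresponding Python bindings:

```python
import torch

# Default comes from the thread_local autocast_xla_dtype initializer above.
assert torch.get_autocast_xla_dtype() == torch.bfloat16
torch.set_autocast_xla_dtype(torch.bfloat16)  # bfloat16 is the only supported value
```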
@@ -34,6 +34,10 @@ TORCH_API bool is_hpu_enabled();
 TORCH_API void set_hpu_enabled(bool enabled);
 TORCH_API at::ScalarType get_autocast_hpu_dtype();
 TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
+TORCH_API bool is_xla_enabled();
+TORCH_API void set_xla_enabled(bool enabled);
+TORCH_API at::ScalarType get_autocast_xla_dtype();
+TORCH_API void set_autocast_xla_dtype(at::ScalarType dtype);
 TORCH_API bool is_privateuseone_enabled();
 TORCH_API void set_privateuseone_enabled(bool enabled);
 TORCH_API at::ScalarType get_autocast_privateuseone_dtype();
@@ -56,6 +60,8 @@ bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
       return tensor.is_ipu() && tensor.is_floating_point();
     case DeviceType::HPU:
       return tensor.is_hpu() && tensor.is_floating_point();
+    case DeviceType::XLA:
+      return tensor.is_xla() && tensor.is_floating_point();
     case DeviceType::PrivateUse1:
       return tensor.device().type() == DeviceType::PrivateUse1 &&
           tensor.is_floating_point();
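Per the eligibility check above, only floating-point XLA tensors are recast; everything else passes through untouched. A hedged sketch (assumes `torch_xla` and an attached XLA device):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
idx = torch.arange(4, device=device)   # int64: not autocast-eligible
x = torch.randn(4, 4, device=device)   # float32: eligible

with torch.amp.autocast('xla'):
    assert (idx + 1).dtype == torch.int64    # integer math untouched
    assert (x @ x).dtype == torch.bfloat16   # float matmul recast
```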
@@ -78,6 +84,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
       return DispatchKey::AutocastIPU;
     case DeviceType::HPU:
       return DispatchKey::AutocastHPU;
+    case DeviceType::XLA:
+      return DispatchKey::AutocastXLA;
     case DeviceType::PrivateUse1:
       return DispatchKey::AutocastPrivateUse1;
     default:
@@ -99,6 +107,8 @@ inline at::ScalarType get_lower_precision_fp_from_device_type(
       return get_autocast_ipu_dtype();
     case DeviceType::HPU:
       return get_autocast_hpu_dtype();
+    case DeviceType::XLA:
+      return get_autocast_xla_dtype();
     case DeviceType::PrivateUse1:
       return get_autocast_privateuseone_dtype();
     default:
@@ -231,6 +231,7 @@ namespace c10 {
   _(aten, has_torch_function) \
   _(aten, is_autocast_enabled) \
   _(aten, is_autocast_cpu_enabled) \
+  _(aten, is_autocast_xla_enabled) \
   FORALL_ATEN_BASE_SYMBOLS(_) \
   _(onnx, Add) \
   _(onnx, Concat) \
@@ -146,6 +146,8 @@ const char* toString(DispatchKey t) {
       return "AutocastHPU";
     case DispatchKey::AutocastCUDA:
       return "AutocastCUDA";
+    case DispatchKey::AutocastXLA:
+      return "AutocastXLA";
     case DispatchKey::AutocastPrivateUse1:
       return "AutocastPrivateUse1";
@@ -293,6 +295,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
       {"AutocastIPU", c10::DispatchKey::AutocastIPU},
       {"AutocastHPU", c10::DispatchKey::AutocastHPU},
       {"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
+      {"AutocastXLA", c10::DispatchKey::AutocastXLA},
       {"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1},
       {"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
       {"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},
@@ -355,8 +355,9 @@ enum class DispatchKey : uint16_t {
   AutocastXPU,
   AutocastIPU,
   AutocastHPU,
-  // Naughtily, AutocastCUDA is also being used for XLA. In the terminal state,
-  // it probably should get its own Autocast key
+  AutocastXLA,
+  // AutocastXLA is only being used for TPUs. XLA GPUs continue to use
+  // AutocastCUDA.
   AutocastCUDA,
   AutocastPrivateUse1,
@@ -645,6 +645,7 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
     DispatchKey::AutocastXPU,
     DispatchKey::AutocastIPU,
     DispatchKey::AutocastHPU,
+    DispatchKey::AutocastXLA,
     DispatchKey::AutocastPrivateUse1,
 });
@@ -660,6 +661,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
     DispatchKey::AutocastXPU,
     DispatchKey::AutocastIPU,
     DispatchKey::AutocastHPU,
+    DispatchKey::AutocastXLA,
     DispatchKey::AutocastPrivateUse1,
 });
@@ -845,6 +847,7 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
   constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
   constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
   constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
+  constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA);
   constexpr auto autocast_privateuse1_ks =
       DispatchKeySet(DispatchKey::AutocastPrivateUse1);
   switch (t) {
@@ -857,8 +860,9 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
     case BackendComponent::HPUBit:
       return autocast_hpu_ks;
     case BackendComponent::CUDABit:
-    case BackendComponent::XLABit:
       return autocast_cuda_ks;
+    case BackendComponent::XLABit:
+      return autocast_xla_ks;
     case BackendComponent::PrivateUse1Bit:
       return autocast_privateuse1_ks;
     default:
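Previously `XLABit` fell through to `autocast_cuda_ks`; it now gets its own `autocast_xla_ks`, so XLA tensors carry `AutocastXLA` in their key set. A hedged inspection sketch (`torch._C._dispatch_keys` is an internal helper and may change between versions):

```python
import torch
import torch_xla.core.xla_model as xm

t = torch.randn(2, device=xm.xla_device())
# Expect AutocastXLA (rather than AutocastCUDA) among the autocast-related keys.
print(torch._C._dispatch_keys(t))
```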
@@ -264,6 +264,10 @@ class TestPublicBindings(TestCase):
             "vitals_enabled",
             "wait",
             "Tag",
+            "set_autocast_xla_enabled",
+            "set_autocast_xla_dtype",
+            "get_autocast_xla_dtype",
+            "is_autocast_xla_enabled",
         }
         torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
@@ -200,6 +200,8 @@ class autocast:
             self.fast_dtype = torch.get_autocast_ipu_dtype()  # type: ignore[attr-defined]
         elif self.device == 'hpu':
             self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
+        elif self.device == 'xla':
+            self.fast_dtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
         elif self.device == self.custom_backend_name:
             necessary_funcs = ['is_autocast_enabled', 'set_autocast_enabled', 'get_autocast_dtype',
                                'set_autocast_dtype', 'get_amp_supported_dtype']
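When no `dtype` argument is given, the new branch above seeds `fast_dtype` from the thread-local XLA default; a hedged sketch:

```python
import torch

ctx = torch.amp.autocast('xla')
assert ctx.fast_dtype == torch.bfloat16  # picked up from get_autocast_xla_dtype()
```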
@@ -266,6 +268,13 @@ class autocast:
         elif self.device == 'cuda':
             if enabled and self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                 raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
+        elif self.device == 'xla':
+            supported_dtype = [torch.bfloat16]
+            if self.fast_dtype not in supported_dtype:
+                error_message = 'In XLA autocast, but the target dtype is not supported. Disabling autocast.\n'
+                error_message += 'XLA Autocast only supports dtype of torch.bfloat16 currently.'
+                warnings.warn(error_message)
+                enabled = False
         self._enabled = enabled

     def __enter__(self):
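Unlike the CUDA branch above, which raises on unsupported bfloat16, the XLA branch warns and silently disables autocast when asked for anything other than bfloat16. A hedged sketch of that behavior:

```python
import warnings
import torch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    # float16 is not supported for 'xla' autocast, so this warns and disables.
    with torch.amp.autocast('xla', dtype=torch.float16):
        assert not torch.is_autocast_xla_enabled()
assert any('XLA Autocast only supports' in str(w.message) for w in caught)
```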
@@ -298,6 +307,12 @@ class autocast:
             torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
             torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
             torch.autocast_increment_nesting()
+        elif self.device == 'xla':
+            self.prev = torch.is_autocast_xla_enabled()  # type: ignore[attr-defined]
+            self.prev_fastdtype = torch.get_autocast_xla_dtype()  # type: ignore[attr-defined]
+            torch.set_autocast_xla_enabled(self._enabled)  # type: ignore[attr-defined]
+            torch.set_autocast_xla_dtype(self.fast_dtype)  # type: ignore[attr-defined]
+            torch.autocast_increment_nesting()
         elif self.device == self.custom_backend_name:
             self.prev = self.custom_device_mod.is_autocast_enabled()
             self.prev_fastdtype = self.custom_device_mod.get_autocast_dtype()
@@ -337,6 +352,11 @@ class autocast:
                 torch.clear_autocast_cache()
             torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
             torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
+        elif self.device == 'xla':
+            if torch.autocast_decrement_nesting() == 0:
+                torch.clear_autocast_cache()
+            torch.set_autocast_xla_enabled(self.prev)  # type: ignore[attr-defined]
+            torch.set_autocast_xla_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
         elif self.device == self.custom_backend_name:
             if torch.autocast_decrement_nesting() == 0:
                 torch.clear_autocast_cache()
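`__enter__` saves the previous enabled flag and dtype, and `__exit__` restores them once nesting unwinds, so contexts nest cleanly. A hedged sketch:

```python
import torch

with torch.amp.autocast('xla'):
    assert torch.is_autocast_xla_enabled()
    with torch.amp.autocast('xla', enabled=False):
        assert not torch.is_autocast_xla_enabled()
    assert torch.is_autocast_xla_enabled()   # inner exit restored the outer state
assert not torch.is_autocast_xla_enabled()   # outer exit restored the default
```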
@@ -478,7 +478,8 @@ static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) {
 static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
   if (at::autocast::is_enabled() || at::autocast::is_cpu_enabled() ||
-      at::autocast::is_xpu_enabled() || at::autocast::is_ipu_enabled()) {
+      at::autocast::is_xpu_enabled() || at::autocast::is_ipu_enabled() ||
+      at::autocast::is_xla_enabled()) {
     Py_RETURN_TRUE;
   } else {
     Py_RETURN_FALSE;
@@ -526,6 +527,26 @@ static PyObject* is_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* set_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  if (!PyBool_Check(arg)) {
+    throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
+  }
+  at::autocast::set_xla_enabled(arg == Py_True);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* is_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  if (at::autocast::is_xla_enabled()) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* set_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
   if (!THPDtype_Check(arg)) {
@@ -562,6 +583,18 @@ static PyObject* set_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* set_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  if (!THPDtype_Check(arg)) {
+    throw TypeError(
+        "dtype must be a torch.dtype (got %s)", Py_TYPE(arg)->tp_name);
+  }
+  at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
+  at::autocast::set_autocast_xla_dtype(targetType);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
   at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype();
@@ -589,6 +622,15 @@ static PyObject* get_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* get_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
+  HANDLE_TH_ERRORS
+  at::ScalarType current_dtype = at::autocast::get_autocast_xla_dtype();
+  auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
+  Py_INCREF(dtype);
+  return dtype;
+  END_HANDLE_TH_ERRORS
+}
+
 static PyObject* clear_autocast_cache(PyObject* _unused, PyObject* arg) {
   HANDLE_TH_ERRORS
   at::autocast::clear_cache();
@@ -884,6 +926,10 @@ static PyMethodDef methods[] = { // NOLINT
     {"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr},
     {"set_autocast_gpu_dtype", set_autocast_gpu_dtype, METH_O, nullptr},
     {"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr},
+    {"set_autocast_xla_enabled", set_autocast_xla_enabled, METH_O, nullptr},
+    {"is_autocast_xla_enabled", is_autocast_xla_enabled, METH_NOARGS, nullptr},
+    {"set_autocast_xla_dtype", set_autocast_xla_dtype, METH_O, nullptr},
+    {"get_autocast_xla_dtype", get_autocast_xla_dtype, METH_NOARGS, nullptr},
     {"set_autocast_ipu_enabled", set_autocast_ipu_enabled, METH_O, nullptr},
     {"is_autocast_ipu_enabled", is_autocast_ipu_enabled, METH_NOARGS, nullptr},
     {"set_autocast_ipu_dtype", set_autocast_ipu_dtype, METH_O, nullptr},
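The four new entries register the XLA helpers in the `torch._C` method table; together with the `test_public_bindings` allowlist above, they surface as public bindings:

```python
import torch

for name in ('set_autocast_xla_enabled', 'is_autocast_xla_enabled',
             'set_autocast_xla_dtype', 'get_autocast_xla_dtype'):
    assert hasattr(torch._C, name) and hasattr(torch, name)
```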
@@ -226,6 +226,8 @@ def get_ignored_functions() -> Set[Callable]:
     torch.clear_autocast_cache,
     torch.set_autocast_cpu_enabled,
     torch.is_autocast_cpu_enabled,
+    torch.set_autocast_xla_enabled,
+    torch.is_autocast_xla_enabled,
     torch.set_autocast_ipu_enabled,
     torch.is_autocast_ipu_enabled,
     torch.set_autocast_cpu_dtype,
@@ -234,6 +236,8 @@ def get_ignored_functions() -> Set[Callable]:
     torch.get_autocast_ipu_dtype,
     torch.get_autocast_gpu_dtype,
     torch.set_autocast_gpu_dtype,
+    torch.get_autocast_xla_dtype,
+    torch.set_autocast_xla_dtype,
     torch.autocast_increment_nesting,
     torch.autocast_decrement_nesting,
     torch.is_autocast_cache_enabled,
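These helpers take no Tensor arguments, so they are exempt from `__torch_function__` dispatch; listing them in `get_ignored_functions()` keeps the override machinery in sync. A quick check:

```python
import torch
from torch.overrides import get_ignored_functions

ignored = get_ignored_functions()
assert torch.set_autocast_xla_enabled in ignored
assert torch.get_autocast_xla_dtype in ignored
```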