[AMP] Support XLA:TPU (#96370)

This lands together with https://github.com/pytorch/xla/pull/5148 and https://github.com/pytorch/xla/pull/4740.

With these changes:
- XLA:GPU users should use `torch.cuda.amp.autocast()` for AMP with float16
- XLA:TPU users should use `torch.amp.autocast('xla')` for AMP with bfloat16
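
A minimal usage sketch (assumes a working `torch_xla` install and an attached TPU; the model, data, and shapes are placeholders, not part of this PR):

```python
import torch
import torch_xla.core.xla_model as xm  # provided by the torch_xla package

device = xm.xla_device()  # a TPU core exposed as an XLA device
model = torch.nn.Linear(8, 8).to(device)
inputs = torch.randn(4, 8, device=device)

# On XLA:TPU, autocast('xla') runs eligible ops in bfloat16.
with torch.amp.autocast('xla', dtype=torch.bfloat16):
    out = model(inputs)

# On XLA:GPU, keep using the CUDA autocast path with float16:
# with torch.cuda.amp.autocast(dtype=torch.float16):
#     out = model(inputs)
```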

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96370
Approved by: https://github.com/bdhirsh, https://github.com/malfet
Meghan
2023-06-23 19:46:42 +00:00
committed by PyTorch MergeBot
parent c17bdb3247
commit 6ff4548b6e
11 changed files with 117 additions and 5 deletions

View File

@ -1 +1 @@
-0eeaefb2341c3beea65545f278cdbd998f5a8399
73392fc2a6c9ec40cba968ea66754514346ac79f

View File

@ -49,6 +49,14 @@ void set_hpu_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastHPU, !new_enabled);
}
bool is_xla_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastXLA);
}
void set_xla_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastXLA, !new_enabled);
}
bool is_privateuseone_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1);
}
@ -98,6 +106,9 @@ thread_local at::ScalarType autocast_ipu_dtype = at::kHalf;
// autocast_hpu_dtype is the lower_precision_fp used by AutocastHPU.
thread_local at::ScalarType autocast_hpu_dtype = at::kBFloat16;
// autocast_xla_dtype is the lower_precision_fp used by AutocastXLA.
thread_local at::ScalarType autocast_xla_dtype = at::kBFloat16;
// should we enable the cache inside autocast.
thread_local bool cache_enabled = true;
@ -141,6 +152,10 @@ at::ScalarType get_autocast_hpu_dtype() {
return autocast_hpu_dtype;
}
at::ScalarType get_autocast_xla_dtype() {
return autocast_xla_dtype;
}
at::ScalarType get_autocast_privateuseone_dtype() {
return autocast_privateuseone_dtype;
}
@ -168,6 +183,10 @@ void set_autocast_hpu_dtype(at::ScalarType dtype) {
autocast_hpu_dtype = dtype;
}
void set_autocast_xla_dtype(at::ScalarType dtype) {
autocast_xla_dtype = dtype;
}
void set_autocast_privateuseone_dtype(at::ScalarType dtype) {
autocast_privateuseone_dtype = dtype;
}
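
For illustration (this snippet is not part of the change): with the Python bindings added later in this commit, the thread-local defaults above are what a fresh thread reports before any `autocast('xla')` region runs.

```python
import torch

# Default state on a fresh thread: AutocastXLA sits in the default excluded
# set (see the DispatchKeySet.h hunk below), and the lower-precision dtype
# defaults to bfloat16 (autocast_xla_dtype = at::kBFloat16 above).
assert not torch.is_autocast_xla_enabled()
assert torch.get_autocast_xla_dtype() == torch.bfloat16
```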

View File

@ -34,6 +34,10 @@ TORCH_API bool is_hpu_enabled();
TORCH_API void set_hpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_hpu_dtype();
TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
TORCH_API bool is_xla_enabled();
TORCH_API void set_xla_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_xla_dtype();
TORCH_API void set_autocast_xla_dtype(at::ScalarType dtype);
TORCH_API bool is_privateuseone_enabled();
TORCH_API void set_privateuseone_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_privateuseone_dtype();
@ -56,6 +60,8 @@ bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
return tensor.is_ipu() && tensor.is_floating_point();
case DeviceType::HPU:
return tensor.is_hpu() && tensor.is_floating_point();
case DeviceType::XLA:
return tensor.is_xla() && tensor.is_floating_point();
case DeviceType::PrivateUse1:
return tensor.device().type() == DeviceType::PrivateUse1 &&
tensor.is_floating_point();
@ -78,6 +84,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
return DispatchKey::AutocastIPU;
case DeviceType::HPU:
return DispatchKey::AutocastHPU;
case DeviceType::XLA:
return DispatchKey::AutocastXLA;
case DeviceType::PrivateUse1:
return DispatchKey::AutocastPrivateUse1;
default:
@ -99,6 +107,8 @@ inline at::ScalarType get_lower_precision_fp_from_device_type(
return get_autocast_ipu_dtype();
case DeviceType::HPU:
return get_autocast_hpu_dtype();
case DeviceType::XLA:
return get_autocast_xla_dtype();
case DeviceType::PrivateUse1:
return get_autocast_privateuseone_dtype();
default:

View File

@ -231,6 +231,7 @@ namespace c10 {
_(aten, has_torch_function) \
_(aten, is_autocast_enabled) \
_(aten, is_autocast_cpu_enabled) \
_(aten, is_autocast_xla_enabled) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \

View File

@ -146,6 +146,8 @@ const char* toString(DispatchKey t) {
return "AutocastHPU";
case DispatchKey::AutocastCUDA:
return "AutocastCUDA";
case DispatchKey::AutocastXLA:
return "AutocastXLA";
case DispatchKey::AutocastPrivateUse1:
return "AutocastPrivateUse1";
@ -293,6 +295,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"AutocastIPU", c10::DispatchKey::AutocastIPU},
{"AutocastHPU", c10::DispatchKey::AutocastHPU},
{"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
{"AutocastXLA", c10::DispatchKey::AutocastXLA},
{"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1},
{"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
{"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},

View File

@ -355,8 +355,9 @@ enum class DispatchKey : uint16_t {
AutocastXPU,
AutocastIPU,
AutocastHPU,
-// Naughtily, AutocastCUDA is also being used for XLA. In the terminal state,
-// it probably should get its own Autocast key
AutocastXLA,
// AutocastXLA is only being used for TPUs. XLA GPUs continue to use
// AutocastCUDA.
AutocastCUDA,
AutocastPrivateUse1,

View File

@ -645,6 +645,7 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
DispatchKey::AutocastXPU,
DispatchKey::AutocastIPU,
DispatchKey::AutocastHPU,
DispatchKey::AutocastXLA,
DispatchKey::AutocastPrivateUse1,
});
@ -660,6 +661,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
DispatchKey::AutocastXPU,
DispatchKey::AutocastIPU,
DispatchKey::AutocastHPU,
DispatchKey::AutocastXLA,
DispatchKey::AutocastPrivateUse1,
});
@ -845,6 +847,7 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
constexpr auto autocast_ipu_ks = DispatchKeySet(DispatchKey::AutocastIPU);
constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
constexpr auto autocast_xla_ks = DispatchKeySet(DispatchKey::AutocastXLA);
constexpr auto autocast_privateuse1_ks =
DispatchKeySet(DispatchKey::AutocastPrivateUse1);
switch (t) {
@ -857,8 +860,9 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
case BackendComponent::HPUBit:
return autocast_hpu_ks;
case BackendComponent::CUDABit:
-case BackendComponent::XLABit:
return autocast_cuda_ks;
case BackendComponent::XLABit:
return autocast_xla_ks;
case BackendComponent::PrivateUse1Bit:
return autocast_privateuse1_ks;
default:

View File

@ -264,6 +264,10 @@ class TestPublicBindings(TestCase):
"vitals_enabled",
"wait",
"Tag",
"set_autocast_xla_enabled",
"set_autocast_xla_dtype",
"get_autocast_xla_dtype",
"is_autocast_xla_enabled",
}
torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}

View File

@ -200,6 +200,8 @@ class autocast:
self.fast_dtype = torch.get_autocast_ipu_dtype() # type: ignore[attr-defined]
elif self.device == 'hpu':
self.fast_dtype = torch.hpu.get_autocast_hpu_dtype() # type: ignore[attr-defined]
elif self.device == 'xla':
self.fast_dtype = torch.get_autocast_xla_dtype() # type: ignore[attr-defined]
elif self.device == self.custom_backend_name:
necessary_funcs = ['is_autocast_enabled', 'set_autocast_enabled', 'get_autocast_dtype',
'set_autocast_dtype', 'get_amp_supported_dtype']
@ -266,6 +268,13 @@ class autocast:
elif self.device == 'cuda':
if enabled and self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
elif self.device == 'xla':
supported_dtype = [torch.bfloat16]
if self.fast_dtype not in supported_dtype:
error_message = 'In XLA autocast, but the target dtype is not supported. Disabling autocast.\n'
error_message += 'XLA Autocast only supports dtype of torch.bfloat16 currently.'
warnings.warn(error_message)
enabled = False
self._enabled = enabled
def __enter__(self):
@ -298,6 +307,12 @@ class autocast:
torch.hpu.set_autocast_hpu_enabled(self._enabled) # type: ignore[attr-defined]
torch.hpu.set_autocast_hpu_dtype(self.fast_dtype) # type: ignore[attr-defined]
torch.autocast_increment_nesting()
elif self.device == 'xla':
self.prev = torch.is_autocast_xla_enabled() # type: ignore[attr-defined]
self.prev_fastdtype = torch.get_autocast_xla_dtype() # type: ignore[attr-defined]
torch.set_autocast_xla_enabled(self._enabled) # type: ignore[attr-defined]
torch.set_autocast_xla_dtype(self.fast_dtype) # type: ignore[attr-defined]
torch.autocast_increment_nesting()
elif self.device == self.custom_backend_name:
self.prev = self.custom_device_mod.is_autocast_enabled()
self.prev_fastdtype = self.custom_device_mod.get_autocast_dtype()
@ -337,6 +352,11 @@ class autocast:
torch.clear_autocast_cache()
torch.hpu.set_autocast_hpu_enabled(self.prev) # type: ignore[attr-defined]
torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == 'xla':
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
torch.set_autocast_xla_enabled(self.prev) # type: ignore[attr-defined]
torch.set_autocast_xla_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
elif self.device == self.custom_backend_name:
if torch.autocast_decrement_nesting() == 0:
torch.clear_autocast_cache()
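
For illustration (not part of the change), a sketch of the behaviour encoded by the `__init__` dtype check and the `__enter__`/`__exit__` branches above; it relies only on the core bindings added in this PR, not on an attached TPU:

```python
import torch

# bfloat16 is the only dtype the 'xla' device type accepts here.
with torch.amp.autocast('xla', dtype=torch.bfloat16):
    assert torch.is_autocast_xla_enabled()
    assert torch.get_autocast_xla_dtype() == torch.bfloat16

# Any other dtype triggers the warning above and runs the region with
# autocast disabled, instead of raising.
with torch.amp.autocast('xla', dtype=torch.float16):
    assert not torch.is_autocast_xla_enabled()

# The previous thread-local state is restored on exit.
assert not torch.is_autocast_xla_enabled()
```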

View File

@ -478,7 +478,8 @@ static PyObject* is_autocast_enabled(PyObject* _unused, PyObject* arg) {
static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (at::autocast::is_enabled() || at::autocast::is_cpu_enabled() ||
-at::autocast::is_xpu_enabled() || at::autocast::is_ipu_enabled()) {
at::autocast::is_xpu_enabled() || at::autocast::is_ipu_enabled() ||
at::autocast::is_xla_enabled()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
@ -526,6 +527,26 @@ static PyObject* is_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) {
END_HANDLE_TH_ERRORS
}
static PyObject* set_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (!PyBool_Check(arg)) {
throw TypeError("enabled must be a bool (got %s)", Py_TYPE(arg)->tp_name);
}
at::autocast::set_xla_enabled(arg == Py_True);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* is_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (at::autocast::is_xla_enabled()) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
END_HANDLE_TH_ERRORS
}
static PyObject* set_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (!THPDtype_Check(arg)) {
@ -562,6 +583,18 @@ static PyObject* set_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
END_HANDLE_TH_ERRORS
}
static PyObject* set_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
if (!THPDtype_Check(arg)) {
throw TypeError(
"dtype must be a torch.dtype (got %s)", Py_TYPE(arg)->tp_name);
}
at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
at::autocast::set_autocast_xla_dtype(targetType);
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_gpu_dtype();
@ -589,6 +622,15 @@ static PyObject* get_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
END_HANDLE_TH_ERRORS
}
static PyObject* get_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
at::ScalarType current_dtype = at::autocast::get_autocast_xla_dtype();
auto dtype = (PyObject*)torch::getTHPDtype(current_dtype);
Py_INCREF(dtype);
return dtype;
END_HANDLE_TH_ERRORS
}
static PyObject* clear_autocast_cache(PyObject* _unused, PyObject* arg) {
HANDLE_TH_ERRORS
at::autocast::clear_cache();
@ -884,6 +926,10 @@ static PyMethodDef methods[] = { // NOLINT
{"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr},
{"set_autocast_gpu_dtype", set_autocast_gpu_dtype, METH_O, nullptr},
{"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr},
{"set_autocast_xla_enabled", set_autocast_xla_enabled, METH_O, nullptr},
{"is_autocast_xla_enabled", is_autocast_xla_enabled, METH_NOARGS, nullptr},
{"set_autocast_xla_dtype", set_autocast_xla_dtype, METH_O, nullptr},
{"get_autocast_xla_dtype", get_autocast_xla_dtype, METH_NOARGS, nullptr},
{"set_autocast_ipu_enabled", set_autocast_ipu_enabled, METH_O, nullptr},
{"is_autocast_ipu_enabled", is_autocast_ipu_enabled, METH_NOARGS, nullptr},
{"set_autocast_ipu_dtype", set_autocast_ipu_dtype, METH_O, nullptr},

View File

@ -226,6 +226,8 @@ def get_ignored_functions() -> Set[Callable]:
torch.clear_autocast_cache,
torch.set_autocast_cpu_enabled,
torch.is_autocast_cpu_enabled,
torch.set_autocast_xla_enabled,
torch.is_autocast_xla_enabled,
torch.set_autocast_ipu_enabled,
torch.is_autocast_ipu_enabled,
torch.set_autocast_cpu_dtype,
@ -234,6 +236,8 @@ def get_ignored_functions() -> Set[Callable]:
torch.get_autocast_ipu_dtype,
torch.get_autocast_gpu_dtype,
torch.set_autocast_gpu_dtype,
torch.get_autocast_xla_dtype,
torch.set_autocast_xla_dtype,
torch.autocast_increment_nesting,
torch.autocast_decrement_nesting,
torch.is_autocast_cache_enabled,