add amp support for custom backend (#96188)

Fixes #ISSUE_NUMBER
1. Add AMP (autocast) support for the custom backend (`PrivateUse1`).
2. Clean up `backend_registration.py` and rename it to `custom_backend_registration.py`, so that further functions for the custom backend can be registered there later.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96188
Approved by: https://github.com/bdhirsh
shibo authored 2023-03-20 20:27:35 +00:00, committed by PyTorch MergeBot
parent a37b4fa03a
commit 6b691b99da
10 changed files with 153 additions and 0 deletions
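
For context, a minimal end-to-end sketch of the flow this PR enables, modeled on the `DummyXPUModule` and `test_external_module_and_backend_register` additions in the test diff below. The backend name `foo` and the `FooModule` class are illustrative placeholders; a real out-of-tree backend would of course also ship its own kernels for the `PrivateUse1` dispatch key.

```python
import torch

class FooModule:
    """Minimal AMP hooks; mirrors the DummyXPUModule added to the test below."""
    @staticmethod
    def is_autocast_foo_enabled(): return True
    @staticmethod
    def get_autocast_foo_dtype(): return torch.float16
    @staticmethod
    def set_autocast_foo_enabled(enabled): pass
    @staticmethod
    def set_autocast_foo_dtype(dtype): pass
    @staticmethod
    def get_amp_supported_dtype(): return [torch.float16]

# Name the PrivateUse1 backend 'foo' (allowed once per process),
# then register its device module under torch.foo.
torch.utils.rename_privateuse1_backend('foo')
torch._register_device_module('foo', FooModule)

# torch.autocast now accepts the custom device type.
with torch.autocast(device_type='foo'):
    pass  # autocast-eligible ops dispatched to 'foo' would run in float16 here
```

If the registered module is missing any of the five hooks, constructing `torch.autocast(device_type='foo')` fails with the assertion message added in `torch/amp/autocast_mode.py` below.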

View File

@@ -46,6 +46,14 @@ void set_hpu_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastHPU, !new_enabled);
}
bool is_privateuseone_enabled() {
return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1);
}
void set_privateuseone_enabled(bool new_enabled) {
c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1, !new_enabled);
}
namespace {
// Imitate Apex and cache some of the casts to streamline parameter reuse.
// Our heuristic is to cache lower_precision_fp casts of fp32 model weights (see cached_cast below).
@@ -88,6 +96,9 @@ thread_local bool cache_enabled = true;
// autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;
// autocast_privateuseone_dtype is the lower_precision_fp used by AutocastPrivateUse1.
thread_local at::ScalarType autocast_privateuseone_dtype = at::kHalf;
}
void clear_cache() {
@@ -119,6 +130,10 @@ at::ScalarType get_autocast_hpu_dtype() {
return autocast_hpu_dtype;
}
at::ScalarType get_autocast_privateuseone_dtype() {
return autocast_privateuseone_dtype;
}
void set_autocast_cpu_dtype(at::ScalarType dtype) {
TORCH_CHECK(
dtype == at::kBFloat16,
@@ -138,6 +153,10 @@ void set_autocast_hpu_dtype(at::ScalarType dtype) {
autocast_hpu_dtype = dtype;
}
void set_autocast_privateuseone_dtype(at::ScalarType dtype) {
autocast_privateuseone_dtype = dtype;
}
bool is_autocast_cache_enabled() {
return cache_enabled;
}

View File

@@ -24,6 +24,10 @@ TORCH_API bool is_hpu_enabled();
TORCH_API void set_hpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_hpu_dtype();
TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
TORCH_API bool is_privateuseone_enabled();
TORCH_API void set_privateuseone_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_privateuseone_dtype();
TORCH_API void set_autocast_privateuseone_dtype(at::ScalarType dtype);
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);
@@ -40,6 +44,9 @@ bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
return tensor.is_xpu() && tensor.is_floating_point();
case DeviceType::HPU:
return tensor.is_hpu() && tensor.is_floating_point();
case DeviceType::PrivateUse1:
return tensor.device().type() == DeviceType::PrivateUse1 &&
tensor.is_floating_point();
default:
return false;
}
@@ -57,6 +64,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
return DispatchKey::AutocastXPU;
case DeviceType::HPU:
return DispatchKey::AutocastHPU;
case DeviceType::PrivateUse1:
return DispatchKey::AutocastPrivateUse1;
default:
throw std::runtime_error(
"unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
@@ -74,6 +83,8 @@ inline at::ScalarType get_lower_precision_fp_from_device_type(
return get_autocast_xpu_dtype();
case DeviceType::HPU:
return get_autocast_hpu_dtype();
case DeviceType::PrivateUse1:
return get_autocast_privateuseone_dtype();
default:
throw std::runtime_error(
"unknown device type for autocast in get_lower_precision_fp_from_device_type");

View File

@@ -144,6 +144,8 @@ const char* toString(DispatchKey t) {
return "AutocastHPU";
case DispatchKey::AutocastCUDA:
return "AutocastCUDA";
case DispatchKey::AutocastPrivateUse1:
return "AutocastPrivateUse1";
case DispatchKey::FuncTorchBatched:
return "FuncTorchBatched";
@@ -285,6 +287,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
{"AutocastXPU", c10::DispatchKey::AutocastXPU},
{"AutocastHPU", c10::DispatchKey::AutocastHPU},
{"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
{"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1},
{"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
{"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},
{"Batched", c10::DispatchKey::Batched},

View File

@@ -356,6 +356,7 @@ enum class DispatchKey : uint16_t {
// Naughtily, AutocastCUDA is also being used for XLA. In the terminal state,
// it probably should get its own Autocast key
AutocastCUDA,
AutocastPrivateUse1,
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// There are a number of alternative modes which may want to handle before

View File

@@ -643,6 +643,7 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
DispatchKey::AutocastCUDA,
DispatchKey::AutocastXPU,
DispatchKey::AutocastHPU,
DispatchKey::AutocastPrivateUse1,
});
// See Note [TLS Initialization]
@@ -656,6 +657,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
DispatchKey::AutocastCUDA,
DispatchKey::AutocastXPU,
DispatchKey::AutocastHPU,
DispatchKey::AutocastPrivateUse1,
});
constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
@@ -839,6 +841,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
constexpr auto autocast_privateuse1_ks =
DispatchKeySet(DispatchKey::AutocastPrivateUse1);
switch (t) {
case BackendComponent::CPUBit:
return autocast_cpu_ks;
@@ -849,6 +853,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
case BackendComponent::CUDABit:
case BackendComponent::XLABit:
return autocast_cuda_ks;
case BackendComponent::PrivateUse1Bit:
return autocast_privateuse1_ks;
default:
return DispatchKeySet();
}

View File

@@ -784,6 +784,26 @@ class DummyXPUModule:
    def is_available():
        return True

    @staticmethod
    def is_autocast_foo_enabled():
        return True

    @staticmethod
    def get_autocast_foo_dtype():
        return torch.float16

    @staticmethod
    def set_autocast_foo_enabled(enable):
        pass

    @staticmethod
    def set_autocast_foo_dtype(dtype):
        pass

    @staticmethod
    def get_amp_supported_dtype():
        return [torch.float16]


class TestExtensionUtils(TestCase):
    def test_external_module_register(self):
@@ -806,6 +826,26 @@ class TestExtensionUtils(TestCase):
        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
            torch._register_device_module('xpu', DummyXPUModule)

    def test_external_module_and_backend_register(self):
        torch.utils.rename_privateuse1_backend('foo')
        with self.assertRaisesRegex(RuntimeError, "has already been set"):
            torch.utils.rename_privateuse1_backend('dummmy')

        custom_backend_name = torch._C._get_privateuse1_backend_name()
        self.assertEqual(custom_backend_name, 'foo')

        with self.assertRaises(AttributeError):
            torch.foo.is_available()

        with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"):
            with torch.autocast(device_type=custom_backend_name):
                pass
        torch._register_device_module('foo', DummyXPUModule)
        torch.foo.is_available()

        with torch.autocast(device_type=custom_backend_name):
            pass


class TestDeviceUtils(TestCase):
    def test_basic(self):

View File

@@ -1233,6 +1233,7 @@ def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
# Defined in torch/csrc/Module.cpp
def _rename_privateuse1_backend(backend: str) -> None: ...
def _get_privateuse1_backend_name() -> str: ...
# Defined in torch/csrc/Generator.cpp
class Generator:

View File

@@ -189,6 +189,7 @@ class autocast:
            assert dtype is not None
            return
        self.device = device_type
        self.custom_backend_name = torch._C._get_privateuse1_backend_name()
        if self.device == 'cuda':
            self.fast_dtype = torch.get_autocast_gpu_dtype()
        elif self.device == 'cpu':
@@ -197,6 +198,24 @@ class autocast:
            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype() # type: ignore[attr-defined]
        elif self.device == 'hpu':
            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype() # type: ignore[attr-defined]
        elif self.device == self.custom_backend_name:
            name_ = self.custom_backend_name
            necessary_funcs = [f'is_autocast_{name_}_enabled', f'set_autocast_{name_}_enabled',
                               f'get_autocast_{name_}_dtype', f'set_autocast_{name_}_dtype',
                               'get_amp_supported_dtype']
            message = f"Tried to use AMP with the `{name_}` backend, but the backend has not registered a module or "
            message += "the module is missing some necessary funcs. The backend should register a corresponding "
            message += "module with `torch._register_device_module`, and the module must have these funcs: \n"
            message += f"`is_autocast_{name_}_enabled() -> bool`, `set_autocast_{name_}_enabled(bool) -> None`, "
            message += f"`get_autocast_{name_}_dtype() -> torch.dtype`, `set_autocast_{name_}_dtype(torch.dtype) "
            message += "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"

            assert hasattr(torch, name_), message
            self.custom_device_mod = getattr(torch, name_)
            for func in necessary_funcs:
                assert hasattr(self.custom_device_mod, func), message + f"But the func `{func}` is missing. \n"

            self.fast_dtype = getattr(self.custom_device_mod, f'get_autocast_{name_}_dtype')()
        else:
            raise RuntimeError('User specified autocast device_type must be \'cuda\' or \'cpu\'')
        self._cache_enabled = torch.is_autocast_cache_enabled()
@@ -229,6 +248,14 @@ class autocast:
                error_message += 'HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == self.custom_backend_name:
            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
            if self.fast_dtype not in supported_dtype:
                error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. "
                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
                error_message += ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'cuda':
            if self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
@@ -258,6 +285,12 @@ class autocast:
            torch.hpu.set_autocast_hpu_enabled(self._enabled) # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype) # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        elif self.device == self.custom_backend_name:
            self.prev = getattr(self.custom_device_mod, f'is_autocast_{self.custom_backend_name}_enabled')()
            self.prev_fastdtype = getattr(self.custom_device_mod, f'get_autocast_{self.custom_backend_name}_dtype')()
            getattr(self.custom_device_mod, f'set_autocast_{self.custom_backend_name}_enabled')(self._enabled)
            getattr(self.custom_device_mod, f'set_autocast_{self.custom_backend_name}_dtype')(self.fast_dtype)
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
@@ -286,6 +319,11 @@ class autocast:
                torch.clear_autocast_cache()
            torch.hpu.set_autocast_hpu_enabled(self.prev) # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype) # type: ignore[attr-defined]
        elif self.device == self.custom_backend_name:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            getattr(self.custom_device_mod, f'set_autocast_{self.custom_backend_name}_enabled')(self.prev)
            getattr(self.custom_device_mod, f'set_autocast_{self.custom_backend_name}_dtype')(self.prev_fastdtype)
        else:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
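
Assuming the illustrative `FooModule` registration from the sketch near the top of this page, the custom-backend branches above reduce to roughly the following sequence of calls around a `with torch.autocast(device_type='foo')` region. This is a hedged sketch of what `__enter__`/`__exit__` do, not code from the PR, and it shows why `__exit__` restores the state saved on entry (`self.prev` / `self.prev_fastdtype`) rather than re-applying the entry values.

```python
import torch

# The backend's user-facing name and its registered device module
# (registered earlier via torch._register_device_module).
name = torch._C._get_privateuse1_backend_name()  # e.g. 'foo'
mod = getattr(torch, name)

# __enter__: remember the current state, then apply the requested one.
prev_enabled = getattr(mod, f'is_autocast_{name}_enabled')()
prev_fastdtype = getattr(mod, f'get_autocast_{name}_dtype')()
getattr(mod, f'set_autocast_{name}_enabled')(True)
getattr(mod, f'set_autocast_{name}_dtype')(torch.float16)
torch.autocast_increment_nesting()

# ... autocast-eligible ops dispatched to the custom backend run here ...

# __exit__: pop one nesting level and restore the state saved on entry,
# so an enclosing autocast region (or the default state) is unaffected.
if torch.autocast_decrement_nesting() == 0:
    torch.clear_autocast_cache()
getattr(mod, f'set_autocast_{name}_enabled')(prev_enabled)
getattr(mod, f'set_autocast_{name}_dtype')(prev_fastdtype)
```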

View File

@@ -464,6 +464,7 @@ PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) {
c10::get_backtrace(frames_to_skip, maximum_number_of_frames, true));
END_HANDLE_TH_ERRORS
}
static PyObject* THModule_rename_privateuse1_backend(
PyObject* _unused,
PyObject* arg) {
@@ -479,6 +480,14 @@ static PyObject* THModule_rename_privateuse1_backend(
END_HANDLE_TH_ERRORS
}
static PyObject* THModule_get_privateuse1_backend_name(
PyObject* _unused,
PyObject* arg) {
HANDLE_TH_ERRORS
return THPUtils_packString(c10::get_privateuse1_backend());
END_HANDLE_TH_ERRORS
}
PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) {
THPUtils_assert(
PyBool_Check(arg),
@@ -1134,6 +1143,10 @@ static PyMethodDef TorchMethods[] = {
THModule_rename_privateuse1_backend,
METH_O,
nullptr},
{"_get_privateuse1_backend_name",
THModule_get_privateuse1_backend_name,
METH_NOARGS,
nullptr},
{"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr},
{"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr},
{"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr},

View File

@@ -17,6 +17,27 @@ def rename_privateuse1_backend(backend_name: str) -> None:
    Note: this API can only be called once per process. Attempting to change
    the external backend after it's already been set will result in an error.

    Note: if you want to support AMP on your device, you need to register a custom backend module.
    The backend must register a custom backend module with `torch._register_device_module("foo", BackendModule)`.
    BackendModule needs to have the following APIs:

    (1) get_amp_supported_dtype() -> List[torch.dtype]
        get the dtypes supported by AMP on your `foo` device; the device may support more than one dtype.

    (2) is_autocast_foo_enabled() -> bool
        check whether AMP is enabled on your `foo` device.

    (3) get_autocast_foo_dtype() -> torch.dtype
        get the dtype currently used by AMP on your `foo` device. The dtype is set by `set_autocast_foo_dtype`
        and defaults to `torch.float16`.

    (4) set_autocast_foo_enabled(bool) -> None
        enable or disable AMP on your `foo` device.

    (5) set_autocast_foo_dtype(dtype) -> None
        set the dtype used by AMP on your `foo` device; the dtype must be one of the dtypes returned by
        `get_amp_supported_dtype`.

    For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
    For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
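
To make the five requirements above concrete, here is a minimal sketch of a `BackendModule` for a backend renamed to `foo`. The class name and the module-level storage are illustrative assumptions; a real backend would typically forward these hooks to thread-local autocast state in its C++ runtime rather than keep them in Python.

```python
import torch

class BackendModule:
    """Illustrative AMP hooks for a PrivateUse1 backend renamed to 'foo'."""
    _autocast_enabled = False
    _autocast_dtype = torch.float16  # default dtype, as described in (3)

    @staticmethod
    def get_amp_supported_dtype():
        # (1) dtypes usable with AMP on the 'foo' device
        return [torch.float16, torch.bfloat16]

    @staticmethod
    def is_autocast_foo_enabled():
        # (2) whether AMP is currently enabled for 'foo'
        return BackendModule._autocast_enabled

    @staticmethod
    def get_autocast_foo_dtype():
        # (3) dtype AMP casts to on 'foo'
        return BackendModule._autocast_dtype

    @staticmethod
    def set_autocast_foo_enabled(enabled):
        # (4) toggle AMP for 'foo'
        BackendModule._autocast_enabled = enabled

    @staticmethod
    def set_autocast_foo_dtype(dtype):
        # (5) a real implementation should validate dtype against get_amp_supported_dtype()
        BackendModule._autocast_dtype = dtype


torch.utils.rename_privateuse1_backend('foo')
torch._register_device_module('foo', BackendModule)
```

Once registered, `torch.autocast(device_type='foo', dtype=torch.bfloat16)` saves, sets, and restores this state exactly as the `torch/amp/autocast_mode.py` changes earlier on this page describe.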