add amp support for custom backend (#96188)
Fixes #ISSUE_NUMBER

1. Add AMP support for a custom backend.
2. Optimize the file `backend_registration.py` and rename it to `custom_backend_registration.py`; other functions for the custom backend can then be registered there.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96188
Approved by: https://github.com/bdhirsh
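At a glance, the intended usage looks like the sketch below. The backend name `foo` and the module `FooBackendModule` are illustrative, not part of this PR; the module must expose the AMP hooks described in the `backend_registration.py` docstring further down in this diff, and a full self-contained version appears near the end of this page.

import torch

# Sketch only: assumes FooBackendModule provides the five AMP hooks
# (is_autocast_foo_enabled, set_autocast_foo_enabled, get_autocast_foo_dtype,
#  set_autocast_foo_dtype, get_amp_supported_dtype), as in the test below.
torch.utils.rename_privateuse1_backend("foo")
torch._register_device_module("foo", FooBackendModule)

with torch.autocast(device_type="foo"):
    pass  # ops on the custom device now dispatch through AutocastPrivateUse1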
@ -46,6 +46,14 @@ void set_hpu_enabled(bool new_enabled) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastHPU, !new_enabled);
}

bool is_privateuseone_enabled() {
  return !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1);
}

void set_privateuseone_enabled(bool new_enabled) {
  c10::impl::tls_set_dispatch_key_excluded(DispatchKey::AutocastPrivateUse1, !new_enabled);
}

namespace {
// Imitate Apex and cache some of the casts to streamline parameter reuse.
// Our heuristic is to cache lower_precision_fp casts of fp32 model weights (see cached_cast below).
@ -88,6 +96,9 @@ thread_local bool cache_enabled = true;
// autocast_gpu_dtype is the lower_precision_fp used by AutocastGPU.
thread_local at::ScalarType autocast_gpu_dtype = at::kHalf;

// autocast_privateuseone_dtype is the lower_precision_fp used by AutocastPrivateUse1.
thread_local at::ScalarType autocast_privateuseone_dtype = at::kHalf;
}

void clear_cache() {
@ -119,6 +130,10 @@ at::ScalarType get_autocast_hpu_dtype() {
  return autocast_hpu_dtype;
}

at::ScalarType get_autocast_privateuseone_dtype() {
  return autocast_privateuseone_dtype;
}

void set_autocast_cpu_dtype(at::ScalarType dtype) {
  TORCH_CHECK(
      dtype == at::kBFloat16,
@ -138,6 +153,10 @@ void set_autocast_hpu_dtype(at::ScalarType dtype) {
  autocast_hpu_dtype = dtype;
}

void set_autocast_privateuseone_dtype(at::ScalarType dtype) {
  autocast_privateuseone_dtype = dtype;
}

bool is_autocast_cache_enabled() {
  return cache_enabled;
}
@ -24,6 +24,10 @@ TORCH_API bool is_hpu_enabled();
TORCH_API void set_hpu_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_hpu_dtype();
TORCH_API void set_autocast_hpu_dtype(at::ScalarType dtype);
TORCH_API bool is_privateuseone_enabled();
TORCH_API void set_privateuseone_enabled(bool enabled);
TORCH_API at::ScalarType get_autocast_privateuseone_dtype();
TORCH_API void set_autocast_privateuseone_dtype(at::ScalarType dtype);
TORCH_API bool is_autocast_cache_enabled();
TORCH_API void set_autocast_cache_enabled(bool enabled);
@ -40,6 +44,9 @@ bool is_autocast_eligible(const Tensor& tensor, DeviceType device_type) {
      return tensor.is_xpu() && tensor.is_floating_point();
    case DeviceType::HPU:
      return tensor.is_hpu() && tensor.is_floating_point();
    case DeviceType::PrivateUse1:
      return tensor.device().type() == DeviceType::PrivateUse1 &&
          tensor.is_floating_point();
    default:
      return false;
  }
@ -57,6 +64,8 @@ inline DispatchKey get_autocast_dispatch_key_from_device_type(
      return DispatchKey::AutocastXPU;
    case DeviceType::HPU:
      return DispatchKey::AutocastHPU;
    case DeviceType::PrivateUse1:
      return DispatchKey::AutocastPrivateUse1;
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_autocast_dispatch_key_from_device_type");
@ -74,6 +83,8 @@ inline at::ScalarType get_lower_precision_fp_from_device_type(
      return get_autocast_xpu_dtype();
    case DeviceType::HPU:
      return get_autocast_hpu_dtype();
    case DeviceType::PrivateUse1:
      return get_autocast_privateuseone_dtype();
    default:
      throw std::runtime_error(
          "unknown device type for autocast in get_lower_precision_fp_from_device_type");
@ -144,6 +144,8 @@ const char* toString(DispatchKey t) {
      return "AutocastHPU";
    case DispatchKey::AutocastCUDA:
      return "AutocastCUDA";
    case DispatchKey::AutocastPrivateUse1:
      return "AutocastPrivateUse1";

    case DispatchKey::FuncTorchBatched:
      return "FuncTorchBatched";
@ -285,6 +287,7 @@ c10::DispatchKey parseDispatchKey(const std::string& k) {
      {"AutocastXPU", c10::DispatchKey::AutocastXPU},
      {"AutocastHPU", c10::DispatchKey::AutocastHPU},
      {"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
      {"AutocastPrivateUse1", c10::DispatchKey::AutocastPrivateUse1},
      {"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
      {"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},
      {"Batched", c10::DispatchKey::Batched},
@ -356,6 +356,7 @@ enum class DispatchKey : uint16_t {
  // Naughtily, AutocastCUDA is also being used for XLA. In the terminal state,
  // it probably should get its own Autocast key
  AutocastCUDA,
  AutocastPrivateUse1,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // There are a number of alternative modes which may want to handle before
@ -643,6 +643,7 @@ constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
    DispatchKey::AutocastCUDA,
    DispatchKey::AutocastXPU,
    DispatchKey::AutocastHPU,
    DispatchKey::AutocastPrivateUse1,
});

// See Note [TLS Initialization]
@ -656,6 +657,7 @@ constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
    DispatchKey::AutocastCUDA,
    DispatchKey::AutocastXPU,
    DispatchKey::AutocastHPU,
    DispatchKey::AutocastPrivateUse1,
});

constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
@ -839,6 +841,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
  constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
  constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
  constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
  constexpr auto autocast_privateuse1_ks =
      DispatchKeySet(DispatchKey::AutocastPrivateUse1);
  switch (t) {
    case BackendComponent::CPUBit:
      return autocast_cpu_ks;
@ -849,6 +853,8 @@ inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
    case BackendComponent::CUDABit:
    case BackendComponent::XLABit:
      return autocast_cuda_ks;
    case BackendComponent::PrivateUse1Bit:
      return autocast_privateuse1_ks;
    default:
      return DispatchKeySet();
  }
@ -784,6 +784,26 @@ class DummyXPUModule:
    def is_available():
        return True

    @staticmethod
    def is_autocast_foo_enabled():
        return True

    @staticmethod
    def get_autocast_foo_dtype():
        return torch.float16

    @staticmethod
    def set_autocast_foo_enabled(enable):
        pass

    @staticmethod
    def set_autocast_foo_dtype(dtype):
        pass

    @staticmethod
    def get_amp_supported_dtype():
        return [torch.float16]


class TestExtensionUtils(TestCase):
    def test_external_module_register(self):
@ -806,6 +826,26 @@ class TestExtensionUtils(TestCase):
        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
            torch._register_device_module('xpu', DummyXPUModule)

    def test_external_module_and_backend_register(self):
        torch.utils.rename_privateuse1_backend('foo')
        with self.assertRaisesRegex(RuntimeError, "has already been set"):
            torch.utils.rename_privateuse1_backend('dummmy')

        custom_backend_name = torch._C._get_privateuse1_backend_name()
        self.assertEqual(custom_backend_name, 'foo')

        with self.assertRaises(AttributeError):
            torch.foo.is_available()

        with self.assertRaisesRegex(AssertionError, "Tried to use AMP with the"):
            with torch.autocast(device_type=custom_backend_name):
                pass
        torch._register_device_module('foo', DummyXPUModule)

        torch.foo.is_available()
        with torch.autocast(device_type=custom_backend_name):
            pass


class TestDeviceUtils(TestCase):
    def test_basic(self):
@ -1233,6 +1233,7 @@ def _get_custom_class_python_wrapper(name: str, attr: str) -> Any: ...
# Defined in torch/csrc/Module.cpp
def _rename_privateuse1_backend(backend: str) -> None: ...
def _get_privateuse1_backend_name() -> str: ...

# Defined in torch/csrc/Generator.cpp
class Generator:
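These two bindings are the C++ entry points behind the public helper. A quick check of the round trip (the name "foo" is only an example, and the rename can only happen once per process):

import torch

torch.utils.rename_privateuse1_backend("foo")    # wraps torch._C._rename_privateuse1_backend
print(torch._C._get_privateuse1_backend_name())  # -> "foo"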
@ -189,6 +189,7 @@ class autocast:
            assert dtype is not None
            return
        self.device = device_type
        self.custom_backend_name = torch._C._get_privateuse1_backend_name()
        if self.device == 'cuda':
            self.fast_dtype = torch.get_autocast_gpu_dtype()
        elif self.device == 'cpu':
@ -197,6 +198,24 @@ class autocast:
            self.fast_dtype = torch.xpu.get_autocast_xpu_dtype()  # type: ignore[attr-defined]
        elif self.device == 'hpu':
            self.fast_dtype = torch.hpu.get_autocast_hpu_dtype()  # type: ignore[attr-defined]
        elif self.device == self.custom_backend_name:
            name_ = self.custom_backend_name
            necessary_funcs = [f'is_autocast_{name_}_enabled', f'set_autocast_{name_}_enabled',
                               f'get_autocast_{name_}_dtype', f'set_autocast_{name_}_dtype',
                               'get_amp_supported_dtype']
            message = f"Tried to use AMP with the `{name_}` backend, but the backend has not registered a module or "
            message += "the module miss some necessary funcs. The backend should register a corresponding module "
            message += "`torch._register_device_module`, and the module must have these funcs: \n"
            message += f"`is_autocast_{name_}_enabled() -> bool`, `set_autocast_{name_}_enabled(bool) -> None`, "
            message += f"`get_autocast_{name_}_dtype() -> torch.dtype`, `set_autocast_{name_}_dtype(torch.dtype) "
            message += "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. \n"

            assert hasattr(torch, name_), message
            self.custom_device_mod = getattr(torch, name_)
            for func in necessary_funcs:
                assert hasattr(self.custom_device_mod, func), message + f"But the func `{func}` is missing. \n"

            self.fast_dtype = getattr(self.custom_device_mod, f'get_autocast_{name_}_dtype')()
        else:
            raise RuntimeError('User specified autocast device_type must be \'cuda\' or \'cpu\'')
        self._cache_enabled = torch.is_autocast_cache_enabled()
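For reference, a small sketch of the lookup this constructor performs: the required attribute names are derived from the current PrivateUse1 backend name and must all exist on the module registered via torch._register_device_module (the name "foo" below is hypothetical).

import torch

name_ = torch._C._get_privateuse1_backend_name()  # "privateuseone" by default, "foo" after renaming
required = [f'is_autocast_{name_}_enabled', f'set_autocast_{name_}_enabled',
            f'get_autocast_{name_}_dtype', f'set_autocast_{name_}_dtype',
            'get_amp_supported_dtype']
custom_device_mod = getattr(torch, name_, None)   # the registered device module, if any
missing = [f for f in required if not hasattr(custom_device_mod, f)]
print(missing)  # must be empty for torch.autocast(device_type=name_) to work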
@ -229,6 +248,14 @@ class autocast:
                error_message += 'HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently.'
                warnings.warn(error_message)
                enabled = False
        elif self.device == self.custom_backend_name:
            supported_dtype = self.custom_device_mod.get_amp_supported_dtype()
            if self.fast_dtype not in supported_dtype:
                error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. "
                error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of "
                error_message += ", ".join(str(dtype) for dtype in supported_dtype) + " currently."
                warnings.warn(error_message)
                enabled = False
        elif self.device == 'cuda':
            if self.fast_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
                raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
@ -258,6 +285,12 @@ class autocast:
            torch.hpu.set_autocast_hpu_enabled(self._enabled)  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.fast_dtype)  # type: ignore[attr-defined]
            torch.autocast_increment_nesting()
        elif self.device == self.custom_backend_name:
            self.prev = getattr(self.custom_device_mod, f'is_autocast_{self.custom_backend_name}_enabled')()
            self.prev_fastdtype = getattr(self.custom_device_mod, f'get_autocast_{self.custom_backend_name}_dtype')()
            getattr(self.custom_device_mod, f'set_autocast_{self.custom_backend_name}_enabled')(self._enabled)
            getattr(self.custom_device_mod, f'set_autocast_{self.custom_backend_name}_dtype')(self.fast_dtype)
            torch.autocast_increment_nesting()
        else:
            self.prev = torch.is_autocast_enabled()
            self.prev_fastdtype = torch.get_autocast_gpu_dtype()
@ -286,6 +319,11 @@ class autocast:
                torch.clear_autocast_cache()
            torch.hpu.set_autocast_hpu_enabled(self.prev)  # type: ignore[attr-defined]
            torch.hpu.set_autocast_hpu_dtype(self.prev_fastdtype)  # type: ignore[attr-defined]
        elif self.device == self.custom_backend_name:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
            getattr(self.custom_device_mod, f'set_autocast_{self.custom_backend_name}_enabled')(self._enabled)
            getattr(self.custom_device_mod, f'set_autocast_{self.custom_backend_name}_dtype')(self.fast_dtype)
        else:
            if torch.autocast_decrement_nesting() == 0:
                torch.clear_autocast_cache()
@ -464,6 +464,7 @@ PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) {
      c10::get_backtrace(frames_to_skip, maximum_number_of_frames, true));
  END_HANDLE_TH_ERRORS
}

static PyObject* THModule_rename_privateuse1_backend(
    PyObject* _unused,
    PyObject* arg) {
@ -479,6 +480,14 @@ static PyObject* THModule_rename_privateuse1_backend(
  END_HANDLE_TH_ERRORS
}

static PyObject* THModule_get_privateuse1_backend_name(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(c10::get_privateuse1_backend());
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) {
  THPUtils_assert(
      PyBool_Check(arg),
@ -1134,6 +1143,10 @@ static PyMethodDef TorchMethods[] = {
     THModule_rename_privateuse1_backend,
     METH_O,
     nullptr},
    {"_get_privateuse1_backend_name",
     THModule_get_privateuse1_backend_name,
     METH_NOARGS,
     nullptr},
    {"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr},
    {"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr},
    {"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr},
@ -17,6 +17,27 @@ def rename_privateuse1_backend(backend_name: str) -> None:
    Note: this API can only be called once per process. Attempting to change
    the external backend after it's already been set will result in an error.

    Note: if you want to support AMP on your device, you can register a custom backend module.
    The backend must register a custom backend module with `torch._register_device_module("foo", BackendModule)`.
    BackendModule needs to have the following API's:

    (1) get_amp_supported_dtype() -> List[torch.dtype]
        get the dtypes supported by your `foo` device in AMP; the `foo` device may support more than one dtype.

    (2) is_autocast_foo_enabled() -> bool
        check whether AMP is enabled on your `foo` device.

    (3) get_autocast_foo_dtype() -> torch.dtype
        get the AMP dtype currently set on your `foo` device, either the value set by `set_autocast_foo_dtype` or the
        default dtype, which is `torch.float16`.

    (4) set_autocast_foo_enabled(bool) -> None
        enable or disable AMP on your `foo` device.

    (5) set_autocast_foo_dtype(dtype) -> None
        set the AMP dtype on your `foo` device; the dtype must be one of the dtypes returned by
        `get_amp_supported_dtype`.

    For more details, see https://pytorch.org/tutorials/advanced/extend_dispatcher.html#get-a-dispatch-key-for-your-backend
    For an existing example, see https://github.com/bdhirsh/pytorch_open_registration_example
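Putting the pieces together, here is a minimal, self-contained sketch of a backend module that satisfies all five hooks and can be used with torch.autocast. It mirrors the DummyXPUModule from the test above; the name "foo" and the module contents are illustrative, and a real backend would back these hooks with its own device state.

import torch

class FooBackendModule:
    # Illustrative AMP state; a real backend would query and update device state here.
    _autocast_enabled = True
    _autocast_dtype = torch.float16

    @staticmethod
    def is_available():
        return True

    @staticmethod
    def is_autocast_foo_enabled():
        return FooBackendModule._autocast_enabled

    @staticmethod
    def set_autocast_foo_enabled(enabled):
        FooBackendModule._autocast_enabled = enabled

    @staticmethod
    def get_autocast_foo_dtype():
        return FooBackendModule._autocast_dtype

    @staticmethod
    def set_autocast_foo_dtype(dtype):
        FooBackendModule._autocast_dtype = dtype

    @staticmethod
    def get_amp_supported_dtype():
        return [torch.float16]

torch.utils.rename_privateuse1_backend("foo")       # once per process
torch._register_device_module("foo", FooBackendModule)

with torch.autocast(device_type="foo"):
    pass  # operators implemented for the `foo` device would autocast here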