Add mem efficient backend flag (#87946)

# Summary
Add a torch.backends.cuda flag and update the context manager to pick between the three implementations of scaled_dot_product_attention.
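
As a quick illustration, here is a minimal usage sketch of the new Python-level controls added by this PR (assuming a PyTorch build that includes this change); the memory-efficient kernel is enabled by default:

```python
import torch

# Query the new flag (defaults to True, mirroring enabled_mem_efficientSDP = true in Context.h).
print(torch.backends.cuda.mem_efficient_sdp_enabled())

# Disable and re-enable the memory-efficient scaled_dot_product_attention kernel.
torch.backends.cuda.enable_mem_efficient_sdp(False)
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # False
torch.backends.cuda.enable_mem_efficient_sdp(True)
```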

cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87946
Approved by: https://github.com/cpuhrsch
Driss Guessous
2022-10-28 15:51:10 +00:00
committed by PyTorch MergeBot
parent 89fd451934
commit 35c611d30f
8 changed files with 92 additions and 15 deletions

View File

@ -112,6 +112,14 @@ void Context::setSDPUseFlash(bool e) {
enabled_flashSDP = e;
}
bool Context::userEnabledMemEfficientSDP() const {
return enabled_mem_efficientSDP;
}
void Context::setSDPUseMemEfficient(bool e) {
enabled_mem_efficientSDP = e;
}
bool Context::userEnabledMathSDP() const {
return enabled_mathSDP;
}

View File

@ -128,8 +128,9 @@ class TORCH_API Context {
// Note [Disabling Fused SDP Kernels]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Flash SDP kernels are enabled by default. However, they can be disabled
// by setting at::globalContext().setUserEnabledFlashSDP(false) flag.
// Flash and Memory Efficient SDP kernels are enabled by default.
// However, they can be disabled by setting
// at::globalContext().setUserEnabledFlashSDP(false) flag.
// This is useful for debugging purposes. For example, if you want to
// compare the performance of the flash SDP kernels with the unfused
// kernel, you can disable the flash SDP kernels. By disabling
@ -139,6 +140,9 @@ class TORCH_API Context {
void setSDPUseFlash(bool);
bool userEnabledFlashSDP() const;
void setSDPUseMemEfficient(bool);
bool userEnabledMemEfficientSDP() const;
void setSDPUseMath(bool);
bool userEnabledMathSDP() const;
@ -270,6 +274,7 @@ class TORCH_API Context {
bool _deterministic_algorithms = false;
bool _deterministic_algorithms_warn_only = false;
bool enabled_flashSDP = true;
bool enabled_mem_efficientSDP = true;
bool enabled_mathSDP = true;
#ifdef USE_ROCM
bool benchmark_cudnn = true;

View File

@ -108,7 +108,7 @@ inline bool check_head_dim_size(sdp_params params, bool debug) {
return true;
}
inline bool check_runtime_disabled(sdp_params params, bool debug) {
inline bool check_runtime_disabled_flash(sdp_params params, bool debug) {
// We check the global context to see if the user has explicitly turned off flash
// sdp kernels
if (!at::globalContext().userEnabledFlashSDP()) {
@ -118,6 +118,16 @@ inline bool check_runtime_disabled(sdp_params params, bool debug) {
return true;
}
inline bool check_runtime_disabled_mem_efficient(sdp_params params, bool debug) {
// We check the global context to see if the user has explicitly turned off mem_efficient
// sdp kernels
if (!at::globalContext().userEnabledMemEfficientSDP()) {
TORCH_CHECK(!debug, "Memory Efficient attention has been runtime disabled.");
return false;
}
return true;
}
inline bool check_gpu_sm75_or_greater(sdp_params params, bool debug) {
// Check that the gpu is capable of running flash attention
auto dprops = at::cuda::getCurrentDeviceProperties();
@ -164,7 +174,7 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
// Define gate functions that determine if a flash kernel can be run
std::vector<std::function<bool(sdp_params, bool)>> constraints{
check_runtime_disabled,
check_runtime_disabled_flash,
check_tensor_shapes,
check_for_attn_weights,
check_for_attn_mask,
@ -193,7 +203,7 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
// Define gate functions that determine if a mem efficient kernel can be run
std::vector<std::function<bool(sdp_params, bool)>> constraints{
check_gpu_sm50_or_greater,
check_runtime_disabled,
check_runtime_disabled_mem_efficient,
check_for_attn_weights,
check_tensor_shapes,
check_for_attn_mask};
@ -214,7 +224,7 @@ inline SDPBackend select_sdp_backend(sdp_params kernel_params) {
// 2. Mem Efficient Attention
// 3. Math fallback
auto& ctx = at::globalContext();
if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP()) {
if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP() && !ctx.userEnabledMemEfficientSDP()) {
return SDPBackend::error;
}
// Because TORCH_CHECK checks if the condition is true we negate debug so that
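
To make the dispatch order easier to follow outside of the C++ diff, here is a rough Python sketch of the priority encoded in select_sdp_backend; the enum values and parameter names below are illustrative stand-ins, not the actual C++ symbols:

```python
from enum import Enum

class SDPBackend(Enum):
    error = 0
    flash_attention = 1
    efficient_attention = 2
    math = 3

def select_sdp_backend_sketch(flash_enabled, mem_efficient_enabled, math_enabled,
                              flash_supported, mem_efficient_supported):
    # If every backend has been runtime-disabled there is nothing to dispatch to.
    if not (flash_enabled or mem_efficient_enabled or math_enabled):
        return SDPBackend.error
    # Priority order: 1. Flash Attention, 2. Mem Efficient Attention, 3. Math fallback.
    if flash_enabled and flash_supported:
        return SDPBackend.flash_attention
    if mem_efficient_enabled and mem_efficient_supported:
        return SDPBackend.efficient_attention
    if math_enabled:
        return SDPBackend.math
    return SDPBackend.error
```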

View File

@ -54,6 +54,10 @@ torch.backends.cuda
.. autofunction:: torch.backends.cuda.flash_sdp_enabled
.. autofunction:: torch.backends.cuda.enable_mem_efficient_sdp
.. autofunction:: torch.backends.cuda.mem_efficient_sdp_enabled
.. autofunction:: torch.backends.cuda.enable_flash_sdp
.. autofunction:: torch.backends.cuda.math_sdp_enabled

View File

@ -1021,12 +1021,12 @@ class TestTransformers(NNTestCase):
def make_tensor(*size, device=device, dtype=dtype):
return torch.randn(size, device=device, dtype=dtype)
with sdp_kernel(enable_flash=False, enable_math=False):
with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):
q, k, v = make_tensor(2, 3, 4), make_tensor(2, 3, 4), make_tensor(2, 3, 4)
self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
lambda: torch.nn.functional._scaled_dot_product_attention(q, k, v))
with sdp_kernel(enable_flash=True, enable_math=False):
with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
# Failures for invalid input
# Dim is not 4
@ -1035,10 +1035,10 @@ class TestTransformers(NNTestCase):
q, k, v, None, 0.0, False, False))
# Xformers can now cover this case but it will be added back in the next PR
# # Invalid last_dim size
# q, k, v = make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4)
# self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
# q, k, v, None, 0.0, False, False))
# Invalid last_dim size
q, k, v = make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4)
self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
q, k, v, None, 0.0, False, False))
# Invalid dtype
q, k, v = make_tensor(2, 2, 3, 16, dtype=torch.float64), make_tensor(
@ -1046,6 +1046,11 @@ class TestTransformers(NNTestCase):
self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
q, k, v, None, 0.0, False, False))
q, k, v = make_tensor(2, 2, 3, 16, dtype=torch.float32), make_tensor(
2, 2, 3, 16, dtype=torch.float32), make_tensor(2, 2, 3, 16, dtype=torch.float32)
self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
q, k, v, None, 0.0, False, False))
# Failures for unsupported SDP args
q, k, v = make_tensor(2, 2, 3, 16), make_tensor(2, 2, 3, 16), make_tensor(2, 2, 3, 16)

View File

@ -819,6 +819,8 @@ def _get_cudnn_enabled() -> _bool: ... # THPModule_userEnabledCuDNN
def _set_cudnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledCuDNN
def _get_flash_sdp_enabled() -> _bool: ... # THPModule_userEnabledFusedSDP
def _set_sdp_use_flash(arg: _bool) -> None: ... # THPModule_setSDPUseFlash
def _get_mem_efficient_sdp_enabled() -> _bool: ... # userEnabledMemEfficientSDP
def _set_sdp_use_mem_efficient(arg: _bool) -> None: ... # THPModule_setSDPUseMemEfficient
def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath
def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn

View File

@ -6,7 +6,8 @@ from typing import Union
__all__ = ["is_built", "cuFFTPlanCacheAttrContextProp", "cuFFTPlanCache", "cuFFTPlanCacheManager",
"cuBLASModule", "preferred_linalg_library", "cufft_plan_cache", "matmul", "enable_flash_sdp",
"flash_sdp_enabled", "math_sdp_enabled", "enable_math_sdp", "sdp_kernel"]
"flash_sdp_enabled", "enable_mem_efficient_sdp", "mem_efficient_sdp_enabled",
"math_sdp_enabled", "enable_math_sdp", "sdp_kernel"]
def is_built():
r"""Returns whether PyTorch is built with CUDA support. Note that this
@ -180,6 +181,22 @@ def enable_flash_sdp(enabled: bool):
"""
torch._C._set_sdp_use_flash(enabled)
def mem_efficient_sdp_enabled():
r"""
.. warning:: This flag is experimental and subject to change.
Returns whether memory efficient sdp is enabled or not.
"""
return torch._C._get_mem_efficient_sdp_enabled()
def enable_mem_efficient_sdp(enabled: bool):
r"""
.. warning:: This flag is experimental and subject to change.
Enables or disables memory efficient sdp.
"""
torch._C._set_sdp_use_mem_efficient(enabled)
def math_sdp_enabled():
r"""
@ -200,23 +217,26 @@ def enable_math_sdp(enabled: bool):
@contextlib.contextmanager
def sdp_kernel(enable_flash: bool = True, enable_math: bool = True):
def sdp_kernel(enable_flash: bool = True, enable_math: bool = True, enable_mem_efficient: bool = True):
r"""
.. warning:: This flag is experimental and subject to change.
This context manager can be used to temporarily enable or disable flash sdp and math sdp.
This context manager can be used to temporarily enable or disable flash/memory efficient sdp and math sdp.
Upon exiting the context manager, the previous state of the flags will be restored.
"""
previous_flash: bool = flash_sdp_enabled()
previous_mem_efficient: bool = mem_efficient_sdp_enabled()
previous_math: bool = math_sdp_enabled()
try:
enable_flash_sdp(enable_flash)
enable_mem_efficient_sdp(enable_mem_efficient)
enable_math_sdp(enable_math)
yield
except RuntimeError as err:
raise err
finally:
enable_flash_sdp(previous_flash)
enable_mem_efficient_sdp(previous_mem_efficient)
enable_math_sdp(previous_math)
cufft_plan_cache = cuFFTPlanCacheManager()
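
And a short usage sketch of the updated context manager with the new enable_mem_efficient argument; this assumes a CUDA device and a build containing this PR, and the call signature of the private _scaled_dot_product_attention follows the test file above:

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(2, 2, 3, 16, device="cuda", dtype=torch.float16)

# Temporarily allow only the memory-efficient kernel; the previous values of
# all three flags are restored when the block exits.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False,
                                    enable_mem_efficient=True):
    out = F._scaled_dot_product_attention(q, k, v, None, 0.0, False, False)
```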

View File

@ -513,6 +513,21 @@ PyObject* THPModule_userEnabledFlashSDP(PyObject* _unused, PyObject* noargs) {
else
Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseMemEfficient(PyObject* _unused, PyObject* arg) {
THPUtils_assert(
PyBool_Check(arg),
"set_sdp_use_math expects a bool, "
"but got %s",
THPUtils_typename(arg));
at::globalContext().setSDPUseMemEfficient(arg == Py_True);
Py_RETURN_NONE;
}
PyObject* userEnabledMemEfficientSDP(PyObject* _unused, PyObject* noargs) {
if (at::globalContext().userEnabledMemEfficientSDP())
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) {
THPUtils_assert(
PyBool_Check(arg),
@ -952,6 +967,14 @@ static PyMethodDef TorchMethods[] = {
METH_NOARGS,
nullptr},
{"_set_sdp_use_flash", THPModule_setSDPUseFlash, METH_O, nullptr},
{"_get_mem_efficient_sdp_enabled",
userEnabledMemEfficientSDP,
METH_NOARGS,
nullptr},
{"_set_sdp_use_mem_efficient",
THPModule_setSDPUseMemEfficient,
METH_O,
nullptr},
{"_get_math_sdp_enabled",
THPModule_userEnabledMathSDP,
METH_NOARGS,