Add mem efficient backend flag (#87946)
# Summary
Add a torch.backends.cuda flag and update the context manager to pick between the three implementations of scaled_dot_product_attention.

cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87946
Approved by: https://github.com/cpuhrsch
Committed by: PyTorch MergeBot
Parent: 89fd451934
Commit: 35c611d30f
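For context, a minimal usage sketch of the surface this PR adds, assuming a build that contains these changes, a CUDA device, and the private `torch.nn.functional._scaled_dot_product_attention` entry point exercised in the tests below:

```python
import torch

# The new module-level toggles mirror the existing flash/math flags.
print(torch.backends.cuda.mem_efficient_sdp_enabled())  # True by default
torch.backends.cuda.enable_mem_efficient_sdp(False)     # runtime-disable the kernel
torch.backends.cuda.enable_mem_efficient_sdp(True)      # re-enable it

# The context manager gains enable_mem_efficient, so a single backend can be
# isolated; the previous flag values are restored on exit.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=True):
    out = torch.nn.functional._scaled_dot_product_attention(q, k, v, None, 0.0, False, False)
```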
@@ -112,6 +112,14 @@ void Context::setSDPUseFlash(bool e) {
   enabled_flashSDP = e;
 }
 
+bool Context::userEnabledMemEfficientSDP() const {
+  return enabled_mem_efficientSDP;
+}
+
+void Context::setSDPUseMemEfficient(bool e) {
+  enabled_mem_efficientSDP = e;
+}
+
 bool Context::userEnabledMathSDP() const {
   return enabled_mathSDP;
 }
@@ -128,8 +128,9 @@ class TORCH_API Context {
 
   // Note [Disabling Fused SDP Kernels]
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-  // Flash SDP kernels are enabled by default. However, they can be disabled
-  // by setting at::globalContext().setUserEnabledFlashSDP(false) flag.
+  // Flash and Memory Efficient SDP kernels are enabled by default.
+  // However, they can be disabled by setting
+  // at::globalContext().setUserEnabledFlashSDP(false) flag.
   // This is useful for debugging purposes. For example, if you want to
   // compare the performance of the flash SDP kernels with the unfused
   // kernel, you can disable the flash SDP kernels. By disabling
@@ -139,6 +140,9 @@ class TORCH_API Context {
   void setSDPUseFlash(bool);
   bool userEnabledFlashSDP() const;
 
+  void setSDPUseMemEfficient(bool);
+  bool userEnabledMemEfficientSDP() const;
+
   void setSDPUseMath(bool);
   bool userEnabledMathSDP() const;
 
@@ -270,6 +274,7 @@ class TORCH_API Context {
   bool _deterministic_algorithms = false;
   bool _deterministic_algorithms_warn_only = false;
   bool enabled_flashSDP = true;
+  bool enabled_mem_efficientSDP = true;
   bool enabled_mathSDP = true;
 #ifdef USE_ROCM
   bool benchmark_cudnn = true;
@@ -108,7 +108,7 @@ inline bool check_head_dim_size(sdp_params params, bool debug) {
   return true;
 }
 
-inline bool check_runtime_disabled(sdp_params params, bool debug) {
+inline bool check_runtime_disabled_flash(sdp_params params, bool debug) {
   // We check the global context to see if user has explicitly turned of flash
   // sdp kernels
   if (!at::globalContext().userEnabledFlashSDP()) {
@@ -118,6 +118,16 @@ inline bool check_runtime_disabled(sdp_params params, bool debug) {
   return true;
 }
 
+inline bool check_runtime_disabled_mem_efficient(sdp_params params, bool debug) {
+  // We check the global context to see if user has explicitly turned of mem_efficient
+  // sdp kernels
+  if (!at::globalContext().userEnabledMemEfficientSDP()) {
+    TORCH_CHECK(!debug, "Memory Efficient attention has been runtime disabled.");
+    return false;
+  }
+  return true;
+}
+
 inline bool check_gpu_sm75_or_greater(sdp_params params, bool debug) {
   // Check that the gpu is capable of running flash attention
   auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -164,7 +174,7 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
 
   // Define gate functions that determine if a flash kernel can be ran
   std::vector<std::function<bool(sdp_params, bool)>> constraints{
-      check_runtime_disabled,
+      check_runtime_disabled_flash,
       check_tensor_shapes,
       check_for_attn_weights,
       check_for_attn_mask,
@@ -193,7 +203,7 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
   // Define gate functions that determine if a flash kernel can be ran
   std::vector<std::function<bool(sdp_params, bool)>> constraints{
       check_gpu_sm50_or_greater,
-      check_runtime_disabled,
+      check_runtime_disabled_mem_efficient,
       check_for_attn_weights,
       check_tensor_shapes,
       check_for_attn_mask};
@@ -214,7 +224,7 @@ inline SDPBackend select_sdp_backend(sdp_params kernel_params) {
   // 2. Mem Efficient Attention
   // 3. Math fallback
   auto& ctx = at::globalContext();
-  if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP()) {
+  if (!ctx.userEnabledMathSDP() && !ctx.userEnabledFlashSDP() && !ctx.userEnabledMemEfficientSDP()) {
     return SDPBackend::error;
   }
   // Because TORCHCHECK checks if condition is true we negate debug so that
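The select_sdp_backend hunk above carries the dispatch logic: try flash attention first, then memory efficient attention, then the math fallback, and error out early only when all three flags are disabled. A rough Python sketch of that ordering, not the C++ implementation: `flash_ok` and `mem_efficient_ok` are hypothetical stand-ins for the constraint lists, the runtime-disable checks are pulled out of those lists for brevity, and only the `error` member of SDPBackend is named in this diff.

```python
from enum import Enum, auto

class SDPBackend(Enum):
    error = auto()
    flash_attention = auto()       # illustrative member name
    efficient_attention = auto()   # illustrative member name
    math = auto()                  # illustrative member name

def select_sdp_backend(params, flash_ok, mem_efficient_ok,
                       flash_enabled, mem_efficient_enabled, math_enabled):
    # New early exit: if every backend has been runtime-disabled, fail.
    if not (math_enabled or flash_enabled or mem_efficient_enabled):
        return SDPBackend.error
    # Priority 1: flash attention, if its gate checks pass.
    if flash_enabled and flash_ok(params):
        return SDPBackend.flash_attention
    # Priority 2: memory efficient attention.
    if mem_efficient_enabled and mem_efficient_ok(params):
        return SDPBackend.efficient_attention
    # Priority 3: the unfused math fallback.
    if math_enabled:
        return SDPBackend.math
    return SDPBackend.error
```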
@@ -54,6 +54,10 @@ torch.backends.cuda
 
 .. autofunction:: torch.backends.cuda.flash_sdp_enabled
 
+.. autofunction:: torch.backends.cuda.enable_mem_efficient_sdp
+
+.. autofunction:: torch.backends.cuda.mem_efficient_sdp_enabled
+
 .. autofunction:: torch.backends.cuda.enable_flash_sdp
 
 .. autofunction:: torch.backends.cuda.math_sdp_enabled
@@ -1021,12 +1021,12 @@ class TestTransformers(NNTestCase):
         def make_tensor(*size, device=device, dtype=dtype):
             return torch.randn(size, device=device, dtype=dtype)
 
-        with sdp_kernel(enable_flash=False, enable_math=False):
+        with sdp_kernel(enable_flash=False, enable_math=False, enable_mem_efficient=False):
             q, k, v = make_tensor(2, 3, 4), make_tensor(2, 3, 4), make_tensor(2, 3, 4)
             self.assertRaisesRegex(RuntimeError, "No viable backend for scaled_dot_product_attention was found.",
                                    lambda: torch.nn.functional._scaled_dot_product_attention(q, k, v))
 
-        with sdp_kernel(enable_flash=True, enable_math=False):
+        with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
             # Failures for invalid input
 
             # Dim is not 4
@@ -1035,10 +1035,10 @@ class TestTransformers(NNTestCase):
                 q, k, v, None, 0.0, False, False))
 
-            # Xformers can now cover this case but will add back in next PR
-            # # Invalid last_dim size
-            # q, k, v = make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4)
-            # self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
-            #     q, k, v, None, 0.0, False, False))
+            # Invalid last_dim size
+            q, k, v = make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4), make_tensor(2, 2, 3, 4)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
+                q, k, v, None, 0.0, False, False))
 
             # Invalid dtype
             q, k, v = make_tensor(2, 2, 3, 16, dtype=torch.float64), make_tensor(
@@ -1046,6 +1046,11 @@ class TestTransformers(NNTestCase):
             self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
                 q, k, v, None, 0.0, False, False))
 
+            q, k, v = make_tensor(2, 2, 3, 16, dtype=torch.float32), make_tensor(
+                2, 2, 3, 16, dtype=torch.float32), make_tensor(2, 2, 3, 16, dtype=torch.float32)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional._scaled_dot_product_attention(
+                q, k, v, None, 0.0, False, False))
+
             # Failures for unsupported SDP args
             q, k, v = make_tensor(2, 2, 3, 16), make_tensor(2, 2, 3, 16), make_tensor(2, 2, 3, 16)
 
@@ -819,6 +819,8 @@ def _get_cudnn_enabled() -> _bool: ... # THPModule_userEnabledCuDNN
 def _set_cudnn_enabled(arg: _bool) -> None: ... # THPModule_setUserEnabledCuDNN
 def _get_flash_sdp_enabled() -> _bool: ... # THPModule_userEnabledFusedSDP
 def _set_sdp_use_flash(arg: _bool) -> None: ... # THPModule_setSDPUseFlash
+def _get_mem_efficient_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
+def _set_sdp_use_mem_efficient(arg: _bool) -> None: ... # THPModule_setSDPUseMemEfficient
 def _get_math_sdp_enabled() -> _bool: ... # THPModule_userEnabledMathSDP
 def _set_sdp_use_math(arg: _bool) -> None: ... # THPModule_setSDPUseMath
 def _get_mkldnn_enabled() -> _bool: ... # THPModule_userEnabledMkldnn
@@ -6,7 +6,8 @@ from typing import Union
 
 __all__ = ["is_built", "cuFFTPlanCacheAttrContextProp", "cuFFTPlanCache", "cuFFTPlanCacheManager",
            "cuBLASModule", "preferred_linalg_library", "cufft_plan_cache", "matmul", "enable_flash_sdp",
-           "flash_sdp_enabled", "math_sdp_enabled", "enable_math_sdp", "sdp_kernel"]
+           "flash_sdp_enabled", "enable_mem_efficient_sdp", "mem_efficient_sdp_enabled",
+           "math_sdp_enabled", "enable_math_sdp", "sdp_kernel"]
 
 def is_built():
     r"""Returns whether PyTorch is built with CUDA support. Note that this
@@ -180,6 +181,22 @@ def enable_flash_sdp(enabled: bool):
     """
     torch._C._set_sdp_use_flash(enabled)
 
+def mem_efficient_sdp_enabled():
+    r"""
+    .. warning:: This flag is experimental and subject to change.
+
+    Returns whether memory efficient sdp is enabled or not.
+    """
+    return torch._C._get_mem_efficient_sdp_enabled()
+
+
+def enable_mem_efficient_sdp(enabled: bool):
+    r"""
+    .. warning:: This flag is experimental and subject to change.
+
+    Enables or disables memory efficient sdp.
+    """
+    torch._C._set_sdp_use_mem_efficient(enabled)
 
 def math_sdp_enabled():
     r"""
@@ -200,23 +217,26 @@ def enable_math_sdp(enabled: bool):
 
 
 @contextlib.contextmanager
-def sdp_kernel(enable_flash: bool = True, enable_math: bool = True):
+def sdp_kernel(enable_flash: bool = True, enable_math: bool = True, enable_mem_efficient: bool = True):
     r"""
     .. warning:: This flag is experimental and subject to change.
 
-    This context manager can be used to temporarily enable or disable flash sdp and math sdp.
+    This context manager can be used to temporarily enable or disable flash/memory efficient sdp and math sdp.
     Upon exiting the context manager, the previous state of the flags will be restored.
     """
     previous_flash: bool = flash_sdp_enabled()
+    previous_mem_efficient: bool = mem_efficient_sdp_enabled()
     previous_math: bool = math_sdp_enabled()
     try:
         enable_flash_sdp(enable_flash)
+        enable_mem_efficient_sdp(enable_mem_efficient)
         enable_math_sdp(enable_math)
         yield{}
     except RuntimeError as err:
         raise err
     finally:
         enable_flash_sdp(previous_flash)
+        enable_mem_efficient_sdp(previous_mem_efficient)
         enable_math_sdp(previous_math)
 
 cufft_plan_cache = cuFFTPlanCacheManager()
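The updated sdp_kernel above snapshots all three flags and restores them in the finally block. A small sanity check of that behavior, assuming a fresh interpreter where all three flags default to True (per the Context.h hunk earlier):

```python
from torch.backends.cuda import (
    sdp_kernel,
    flash_sdp_enabled,
    mem_efficient_sdp_enabled,
    math_sdp_enabled,
)

assert flash_sdp_enabled() and mem_efficient_sdp_enabled() and math_sdp_enabled()
with sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False):
    # Inside the block only the flash backend remains eligible.
    assert flash_sdp_enabled()
    assert not mem_efficient_sdp_enabled()
    assert not math_sdp_enabled()
# On exit the previous values are restored (the finally branch runs even on error).
assert flash_sdp_enabled() and mem_efficient_sdp_enabled() and math_sdp_enabled()
```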
@@ -513,6 +513,21 @@ PyObject* THPModule_userEnabledFlashSDP(PyObject* _unused, PyObject* noargs) {
   else
     Py_RETURN_FALSE;
 }
+PyObject* THPModule_setSDPUseMemEfficient(PyObject* _unused, PyObject* arg) {
+  THPUtils_assert(
+      PyBool_Check(arg),
+      "set_sdp_use_math expects a bool, "
+      "but got %s",
+      THPUtils_typename(arg));
+  at::globalContext().setSDPUseMemEfficient(arg == Py_True);
+  Py_RETURN_NONE;
+}
+PyObject* userEnabledMemEfficientSDP(PyObject* _unused, PyObject* noargs) {
+  if (at::globalContext().userEnabledMemEfficientSDP())
+    Py_RETURN_TRUE;
+  else
+    Py_RETURN_FALSE;
+}
 PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) {
   THPUtils_assert(
       PyBool_Check(arg),
@@ -952,6 +967,14 @@ static PyMethodDef TorchMethods[] = {
      METH_NOARGS,
      nullptr},
     {"_set_sdp_use_flash", THPModule_setSDPUseFlash, METH_O, nullptr},
+    {"_get_mem_efficient_sdp_enabled",
+     userEnabledMemEfficientSDP,
+     METH_NOARGS,
+     nullptr},
+    {"_set_sdp_use_mem_efficient",
+     THPModule_setSDPUseMemEfficient,
+     METH_O,
+     nullptr},
     {"_get_math_sdp_enabled",
      THPModule_userEnabledMathSDP,
      METH_NOARGS,