Add mem efficient backend flag (#87946)

# Summary
Add a torch.backends.cuda flag and update the context manager to pick between the three implementations of scaled_dot_product_attention.
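
For illustration, a minimal sketch of how the new flag and the updated context manager are meant to be used. The Python-side wrapper names (`mem_efficient_sdp_enabled`, `enable_mem_efficient_sdp`, `sdp_kernel`) and the attention call are assumptions about how the bindings in this diff are surfaced, not part of the diff itself:

```python
# Sketch only: the torch.backends.cuda wrapper names below are assumed, not taken from this diff.
import torch
import torch.nn.functional as F

# Query and toggle the new memory-efficient backend flag
# (assumed wrappers over _get_mem_efficient_sdp_enabled / _set_sdp_use_mem_efficient).
print(torch.backends.cuda.mem_efficient_sdp_enabled())
torch.backends.cuda.enable_mem_efficient_sdp(True)

# Scope the choice between the three implementations with the context manager.
q = k = v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float16)
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=True, enable_mem_efficient=True
):
    out = F.scaled_dot_product_attention(q, k, v)  # exact name/visibility at this commit may differ
```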

cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87946
Approved by: https://github.com/cpuhrsch
Driss Guessous
2022-10-28 15:51:10 +00:00
committed by PyTorch MergeBot
parent 89fd451934
commit 35c611d30f
8 changed files with 92 additions and 15 deletions


@@ -513,6 +513,21 @@ PyObject* THPModule_userEnabledFlashSDP(PyObject* _unused, PyObject* noargs) {
else
Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseMemEfficient(PyObject* _unused, PyObject* arg) {
THPUtils_assert(
PyBool_Check(arg),
"set_sdp_use_math expects a bool, "
"but got %s",
THPUtils_typename(arg));
at::globalContext().setSDPUseMemEfficient(arg == Py_True);
Py_RETURN_NONE;
}
PyObject* userEnabledMemEfficientSDP(PyObject* _unused, PyObject* noargs) {
if (at::globalContext().userEnabledMemEfficientSDP())
Py_RETURN_TRUE;
else
Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) {
THPUtils_assert(
PyBool_Check(arg),
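
The two new bindings above are registered on `torch._C` in the method table below. A minimal sketch of calling them directly (the public `torch.backends.cuda` helpers are the intended entry points):

```python
import torch

# Getter registered as _get_mem_efficient_sdp_enabled (METH_NOARGS): returns the global flag.
enabled = torch._C._get_mem_efficient_sdp_enabled()
assert isinstance(enabled, bool)

# Setter registered as _set_sdp_use_mem_efficient (METH_O): takes a single bool;
# a non-bool argument fails the PyBool_Check assert above and raises an error.
torch._C._set_sdp_use_mem_efficient(True)
```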
@@ -952,6 +967,14 @@ static PyMethodDef TorchMethods[] = {
METH_NOARGS,
nullptr},
{"_set_sdp_use_flash", THPModule_setSDPUseFlash, METH_O, nullptr},
{"_get_mem_efficient_sdp_enabled",
userEnabledMemEfficientSDP,
METH_NOARGS,
nullptr},
{"_set_sdp_use_mem_efficient",
THPModule_setSDPUseMemEfficient,
METH_O,
nullptr},
{"_get_math_sdp_enabled",
THPModule_userEnabledMathSDP,
METH_NOARGS,