Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 13:44:15 +08:00)
Add mem efficient backend flag (#87946)

# Summary
Add a torch.backends.cuda flag and update the context manager to pick between the three implementations of scaled_dot_product_attention.

cc @cpuhrsch @jbschlosser @bhosmer @mikaylagawarecki

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87946
Approved by: https://github.com/cpuhrsch
Committed by: PyTorch MergeBot
Parent: 89fd451934
Commit: 35c611d30f
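At the C++ level this change registers two new private bindings on torch._C, _get_mem_efficient_sdp_enabled and _set_sdp_use_mem_efficient (see the method-table hunk below). A minimal sketch of exercising the flag through those bindings; the public torch.backends.cuda wrappers mentioned in the summary are added elsewhere in the PR and are not shown here:

    import torch

    # Query the current state of the memory-efficient SDP backend flag.
    print(torch._C._get_mem_efficient_sdp_enabled())

    # The setter asserts PyBool_Check(arg), so it must be given a real bool.
    torch._C._set_sdp_use_mem_efficient(False)
    assert torch._C._get_mem_efficient_sdp_enabled() is False

    torch._C._set_sdp_use_mem_efficient(True)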
@@ -513,6 +513,21 @@ PyObject* THPModule_userEnabledFlashSDP(PyObject* _unused, PyObject* noargs) {
   else
     Py_RETURN_FALSE;
 }
+PyObject* THPModule_setSDPUseMemEfficient(PyObject* _unused, PyObject* arg) {
+  THPUtils_assert(
+      PyBool_Check(arg),
+      "set_sdp_use_math expects a bool, "
+      "but got %s",
+      THPUtils_typename(arg));
+  at::globalContext().setSDPUseMemEfficient(arg == Py_True);
+  Py_RETURN_NONE;
+}
+PyObject* userEnabledMemEfficientSDP(PyObject* _unused, PyObject* noargs) {
+  if (at::globalContext().userEnabledMemEfficientSDP())
+    Py_RETURN_TRUE;
+  else
+    Py_RETURN_FALSE;
+}
 PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) {
   THPUtils_assert(
       PyBool_Check(arg),
@@ -952,6 +967,14 @@ static PyMethodDef TorchMethods[] = {
      METH_NOARGS,
      nullptr},
     {"_set_sdp_use_flash", THPModule_setSDPUseFlash, METH_O, nullptr},
+    {"_get_mem_efficient_sdp_enabled",
+     userEnabledMemEfficientSDP,
+     METH_NOARGS,
+     nullptr},
+    {"_set_sdp_use_mem_efficient",
+     THPModule_setSDPUseMemEfficient,
+     METH_O,
+     nullptr},
     {"_get_math_sdp_enabled",
      THPModule_userEnabledMathSDP,
      METH_NOARGS,
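The context-manager part of the PR is not in these hunks. As a rough sketch only, a manager that picks between the three scaled_dot_product_attention implementations could save and restore the three flags through the private bindings; _get_flash_sdp_enabled and _set_sdp_use_math are assumed here by analogy with the names visible in the diff, and the real manager may differ:

    from contextlib import contextmanager
    import torch

    @contextmanager
    def sdp_backend(enable_flash=True, enable_mem_efficient=True, enable_math=True):
        # Remember the flags that are active on entry.
        # _get_flash_sdp_enabled / _set_sdp_use_math are assumed names, not shown in this diff.
        prev_flash = torch._C._get_flash_sdp_enabled()
        prev_mem = torch._C._get_mem_efficient_sdp_enabled()
        prev_math = torch._C._get_math_sdp_enabled()
        try:
            torch._C._set_sdp_use_flash(enable_flash)
            torch._C._set_sdp_use_mem_efficient(enable_mem_efficient)
            torch._C._set_sdp_use_math(enable_math)
            yield
        finally:
            # Restore the previous configuration on exit.
            torch._C._set_sdp_use_flash(prev_flash)
            torch._C._set_sdp_use_mem_efficient(prev_mem)
            torch._C._set_sdp_use_math(prev_math)

For example, entering "with sdp_backend(enable_flash=False, enable_math=False):" would leave only the memory-efficient kernel enabled for scaled_dot_product_attention inside the block.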