Enable nested default hooks (#70932)

Summary:
When default hooks are set, they are pushed onto a stack.
When nesting context-manager, only the inner-most hooks will
be applied.

There is special care needed to update the TLS code. See also https://github.com/pytorch/pytorch/issues/70940 (i.e. do we need to be storing the enabled flag as well?)

Fixes https://github.com/pytorch/pytorch/issues/70134

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70932

Reviewed By: mruberry

Differential Revision: D33530370

Pulled By: albanD

fbshipit-source-id: 3197d585d77563f36c175d3949115a0776b309f4
This commit is contained in:
Victor Quach
2022-01-11 15:02:13 -08:00
committed by Facebook GitHub Bot
parent 433cf44b79
commit a3b7dd7b78
11 changed files with 98 additions and 53 deletions

View File

@ -46,25 +46,21 @@ namespace torch { namespace autograd {
}
}
void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
PyObject *pack_hook_(nullptr), *unpack_hook_(nullptr);
std::tie(pack_hook_, unpack_hook_) = at::SavedTensorDefaultHooks::get_hooks();
TORCH_CHECK(!pack_hook_ && !unpack_hook_,
"Setting default hooks but they have already been set. "
"Hint: only one pair of hooks is allowed at a time.");
void PyDefaultSavedVariableHooks::push_hooks(py::function &pack_hook, py::function &unpack_hook) {
at::SavedTensorDefaultHooks::enable();
at::SavedTensorDefaultHooks::set_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
at::SavedTensorDefaultHooks::push_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
}
void PyDefaultSavedVariableHooks::reset_hooks() {
void PyDefaultSavedVariableHooks::pop_hooks() {
PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
if (Py_IsInitialized()) {
py::gil_scoped_acquire gil;
Py_XDECREF(pack_hook);
Py_XDECREF(unpack_hook);
}
at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr);
at::SavedTensorDefaultHooks::pop_hooks();
}
std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {