mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-03 15:35:04 +08:00
Enable nested default hooks (#70932)
Summary: When default hooks are set, they are pushed onto a stack. When nesting context-manager, only the inner-most hooks will be applied. There is special care needed to update the TLS code. See also https://github.com/pytorch/pytorch/issues/70940 (i.e. do we need to be storing the enabled flag as well?) Fixes https://github.com/pytorch/pytorch/issues/70134 Pull Request resolved: https://github.com/pytorch/pytorch/pull/70932 Reviewed By: mruberry Differential Revision: D33530370 Pulled By: albanD fbshipit-source-id: 3197d585d77563f36c175d3949115a0776b309f4
This commit is contained in:
committed by
Facebook GitHub Bot
parent
433cf44b79
commit
a3b7dd7b78
@ -46,25 +46,21 @@ namespace torch { namespace autograd {
|
||||
}
|
||||
}
|
||||
|
||||
void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
|
||||
PyObject *pack_hook_(nullptr), *unpack_hook_(nullptr);
|
||||
std::tie(pack_hook_, unpack_hook_) = at::SavedTensorDefaultHooks::get_hooks();
|
||||
TORCH_CHECK(!pack_hook_ && !unpack_hook_,
|
||||
"Setting default hooks but they have already been set. "
|
||||
"Hint: only one pair of hooks is allowed at a time.");
|
||||
void PyDefaultSavedVariableHooks::push_hooks(py::function &pack_hook, py::function &unpack_hook) {
|
||||
at::SavedTensorDefaultHooks::enable();
|
||||
at::SavedTensorDefaultHooks::set_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
|
||||
at::SavedTensorDefaultHooks::push_hooks(pack_hook.release().ptr(), unpack_hook.release().ptr());
|
||||
}
|
||||
|
||||
void PyDefaultSavedVariableHooks::reset_hooks() {
|
||||
void PyDefaultSavedVariableHooks::pop_hooks() {
|
||||
PyObject *pack_hook(nullptr), *unpack_hook(nullptr);
|
||||
std::tie(pack_hook, unpack_hook) = at::SavedTensorDefaultHooks::get_hooks();
|
||||
TORCH_INTERNAL_ASSERT(pack_hook != nullptr && unpack_hook != nullptr);
|
||||
if (Py_IsInitialized()) {
|
||||
py::gil_scoped_acquire gil;
|
||||
Py_XDECREF(pack_hook);
|
||||
Py_XDECREF(unpack_hook);
|
||||
}
|
||||
at::SavedTensorDefaultHooks::set_hooks(nullptr, nullptr);
|
||||
at::SavedTensorDefaultHooks::pop_hooks();
|
||||
}
|
||||
|
||||
std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
|
||||
|
||||
Reference in New Issue
Block a user