Files
pytorch/torch/csrc/autograd/python_saved_variable_hooks.h
Victor Quach a3b7dd7b78 Enable nested default hooks (#70932)
Summary:
When default hooks are set, they are pushed onto a stack.
When context managers are nested, only the innermost hooks
are applied.

Special care is needed when updating the thread-local state (TLS) code. See also https://github.com/pytorch/pytorch/issues/70940 (i.e., whether the enabled flag needs to be stored as well).

Fixes https://github.com/pytorch/pytorch/issues/70134

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70932

Reviewed By: mruberry

Differential Revision: D33530370

Pulled By: albanD

fbshipit-source-id: 3197d585d77563f36c175d3949115a0776b309f4
2022-01-11 15:03:49 -08:00

33 lines
878 B
C++

#pragma once
#include <pybind11/pybind11.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/Export.h>
#include <ATen/ATen.h>
namespace py = pybind11;
namespace torch { namespace autograd {
// SavedVariableHooks implementation backed by a pair of Python callables:
// a pack hook and an unpack hook.
//
// The hooks (and the packed data) are stored as raw PyObject* members, and the
// destructor is user-declared and defined out of line.
// NOTE(review): the members are presumably owned references released by the
// destructor -- confirm in the .cpp.
// An implicitly generated copy would make two objects share (and both release)
// the same pointers. Copy and move operations are therefore explicitly
// deleted (Rule of Five). Instances are meant to be held through a pointer
// such as std::unique_ptr<SavedVariableHooks>.
struct PySavedVariableHooks : public SavedVariableHooks {
  // The hooks are taken by non-const reference.
  // NOTE(review): presumably so the implementation can take over the
  // references held by the py::function objects -- confirm in the .cpp.
  PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook);

  PySavedVariableHooks(const PySavedVariableHooks&) = delete;
  PySavedVariableHooks& operator=(const PySavedVariableHooks&) = delete;
  PySavedVariableHooks(PySavedVariableHooks&&) = delete;
  PySavedVariableHooks& operator=(PySavedVariableHooks&&) = delete;

  // Overrides of the SavedVariableHooks interface (defined out of line).
  // NOTE(review): call_pack_hook presumably runs the pack hook on `tensor` and
  // keeps the result in data_. call_unpack_hook presumably runs the unpack
  // hook on that data and returns the resulting tensor -- confirm.
  void call_pack_hook(const at::Tensor &tensor) override;
  at::Tensor call_unpack_hook() override;
  ~PySavedVariableHooks() override;
 private:
  // Default-initialized so the pointers are never indeterminate, even if a
  // constructor path fails to set them. The constructor's mem-initializers
  // take precedence when present.
  PyObject* pack_hook_ = nullptr;
  PyObject* unpack_hook_ = nullptr;
  // nullptr until packed data is stored.
  PyObject* data_ = nullptr;
};
// Static entry points for managing the default saved-variable hooks.
// Defaults are kept on a stack so that nested context managers compose: only
// the innermost (most recently pushed) pack/unpack pair is applied.
// NOTE(review): the stack presumably lives in thread-local state -- confirm
// against the implementation.
struct PyDefaultSavedVariableHooks {
  // Returns a SavedVariableHooks object for the innermost default pair.
  // NOTE(review): presumably returns nullptr when no defaults are set --
  // confirm in the .cpp.
  static std::unique_ptr<SavedVariableHooks> get_hooks();

  // Makes (pack, unpack) the new innermost pair of default hooks.
  static void push_hooks(py::function& pack, py::function& unpack);

  // Discards the innermost pair of default hooks.
  static void pop_hooks();
};
}}