pytorch/torch/csrc/autograd/python_saved_variable_hooks.cpp
Victor Quach b161ac541d [reland] Add default Saved Variable hooks (#62563)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62563

Expose a pair of functions to Python users: torch.autograd.graph.set_saved_tensors_default_hooks(pack, unpack) and torch.autograd.graph.reset_saved_tensors_default_hooks().
These functions control the hooks applied to saved tensors: while the default hooks are set, every tensor that autograd saves for backward is packed with the pack function and later unpacked with the unpack function when it is needed.
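
As a rough illustration of the intended usage (the tensors and print statements below are just for demonstration and are not part of this PR), pack runs when autograd saves a tensor for backward and unpack runs when that saved tensor is needed again:

```
import torch

def pack(x):
    print("packing", x.shape)
    return x  # the pack hook may return any object

def unpack(x):
    print("unpacking", x.shape)
    return x  # the unpack hook must return a Tensor

torch.autograd.graph.set_saved_tensors_default_hooks(pack, unpack)
a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
y = (a * b).sum()  # a and b are saved for backward, so pack runs here
y.backward()       # the saved tensors are needed, so unpack runs here
torch.autograd.graph.reset_saved_tensors_default_hooks()
```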

Currently, this works by simply calling register_hooks (cf. #60975) directly at the end of the SavedVariable constructor. This could be optimized further by not performing the copy before registering the default hooks, but that would require a small refactor. Edit: the refactor is done in #61927.

A current limitation is that for tensors saved while the default hooks are set, users will not be able to register additional hooks on the corresponding saved tensors.

For instance, to implement something like #28997, one could define a pack function that saves the tensor to disk whenever it is too big and returns the filename, while unpack simply reads the file back and returns the tensor, e.g.:

```
import os
import tempfile
import uuid
import torch

tmp_dir = tempfile.mkdtemp()  # directory holding the offloaded tensors

def pack(x):
    name = os.path.join(tmp_dir, str(uuid.uuid4()))
    torch.save(x, name)
    return name

def unpack(name):
    return torch.load(name)
```
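
Under the same assumptions (pack, unpack and tmp_dir as defined above; model and inputs are placeholders), the hooks could be installed around a memory-hungry section and reset afterwards so later graphs are unaffected:

```
torch.autograd.graph.set_saved_tensors_default_hooks(pack, unpack)
try:
    out = model(inputs)   # tensors saved for backward are written to disk
    out.sum().backward()  # unpack reloads them from disk when they are needed
finally:
    torch.autograd.graph.reset_saved_tensors_default_hooks()
```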

Relanding previous PR: https://github.com/pytorch/pytorch/pull/61834

The original PR led to a timeout error in: https://www.internalfb.com/mast/job/yuguo-release_canary_offline_training-inlinecvrp_a-canary_offline_train_28a7ecfc

Now passing: https://www.internalfb.com/mast/job/quach-release_canary_offline_training-inlinecvrp_a-canary_offline_train_9bb57e98

The difference with the new version is that `PyDefaultSavedVariableHooks::get_hooks` no longer acquires the GIL when no default hooks are set: it returns early before the GIL acquisition.

Test Plan: Imported from OSS

Reviewed By: iramazanli

Differential Revision: D30045405

Pulled By: Varal7

fbshipit-source-id: 7f6c07af3a56fe8835d5edcc815c15ea4fb4e332
2021-08-02 11:30:26 -07:00

#include <torch/csrc/autograd/python_saved_variable_hooks.h>

#include <torch/csrc/THP.h>

namespace py = pybind11;

namespace torch { namespace autograd {

PySavedVariableHooks::PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook) :
    // steals the reference (we will decref ourselves)
    pack_hook_(pack_hook.release().ptr()),
    unpack_hook_(unpack_hook.release().ptr()) {}

void PySavedVariableHooks::call_pack_hook(at::Tensor &tensor) {
  py::gil_scoped_acquire acquire;
  auto pack_hook = py::reinterpret_borrow<py::function>(pack_hook_);
  auto wrapped = THPVariable_Wrap(tensor);
  py::object obj = py::reinterpret_steal<py::object>(wrapped);
  py::object packed = pack_hook(obj);
  data_ = packed.release().ptr();
  // pack_hook, obj are decrefed on exit
  // wrapped and packed had their references stolen
  // pack_hook_ and data_ will be manually decrefed when the saved variable is released
}

at::Tensor PySavedVariableHooks::call_unpack_hook() {
  py::gil_scoped_acquire acquire;
  auto unpack_hook = py::reinterpret_borrow<py::function>(unpack_hook_);
  py::object obj = py::cast<py::object>(data_);
  py::object res = unpack_hook(obj);
  PyObject* ptr = res.ptr();
  TORCH_CHECK_TYPE(THPVariable_Check(ptr), "Output of saved tensor unpack_hook expected to be a Tensor but got result of type ", THPUtils_typename(ptr));
  return THPVariable_Unpack(ptr);
  // unpack_hook, obj and res are decrefed on exit
  // ptr is only alive as long as res is
  // unpack_hook_ will be manually decrefed when the saved variable is released
}

PySavedVariableHooks::~PySavedVariableHooks() {
  // If python is already dead, leak the wrapped python objects
  if (Py_IsInitialized()) {
    py::gil_scoped_acquire gil;
    Py_XDECREF(pack_hook_);
    Py_XDECREF(unpack_hook_);
    Py_XDECREF(data_);
  }
}

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* PyDefaultSavedVariableHooks::pack_hook_ = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* PyDefaultSavedVariableHooks::unpack_hook_ = nullptr;

void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
  TORCH_CHECK(!pack_hook_ && !unpack_hook_,
      "Setting default hooks but they have already been set. "
      "Hint: only one pair of hooks is allowed at a time.");
  pack_hook_ = pack_hook.release().ptr();
  unpack_hook_ = unpack_hook.release().ptr();
}

void PyDefaultSavedVariableHooks::reset_hooks() {
  if (Py_IsInitialized()) {
    py::gil_scoped_acquire gil;
    Py_XDECREF(pack_hook_);
    Py_XDECREF(unpack_hook_);
  }
  pack_hook_ = nullptr;
  unpack_hook_ = nullptr;
}

std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
  if (!pack_hook_ || !unpack_hook_) {
    return nullptr;
  }
  py::gil_scoped_acquire gil;
  py::function pack_hook = py::reinterpret_borrow<py::function>(pack_hook_);
  py::function unpack_hook = py::reinterpret_borrow<py::function>(unpack_hook_);
  return std::make_unique<PySavedVariableHooks>(pack_hook, unpack_hook);
}

}}
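
Both checks above are observable from the Python API: calling set_saved_tensors_default_hooks a second time without resetting trips the TORCH_CHECK in set_hooks, and an unpack hook that does not return a Tensor trips the TORCH_CHECK_TYPE in call_unpack_hook. A rough sketch, assuming the usual TORCH_CHECK -> RuntimeError and TORCH_CHECK_TYPE -> TypeError mapping:

```
import torch

# Setting a second pair of default hooks without a reset is rejected.
torch.autograd.graph.set_saved_tensors_default_hooks(lambda x: x, lambda x: x)
try:
    torch.autograd.graph.set_saved_tensors_default_hooks(lambda x: x, lambda x: x)
except RuntimeError as e:
    print(e)  # "... only one pair of hooks is allowed at a time."
finally:
    torch.autograd.graph.reset_saved_tensors_default_hooks()

# An unpack hook returning a non-Tensor fails the THPVariable_Check above.
torch.autograd.graph.set_saved_tensors_default_hooks(lambda x: x.tolist(), lambda x: x)
a = torch.randn(2, requires_grad=True)
y = (a * a).sum()
torch.autograd.graph.reset_saved_tensors_default_hooks()
try:
    y.backward()
except TypeError as e:
    print(e)  # "Output of saved tensor unpack_hook expected to be a Tensor ..."
```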