mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-27 17:54:55 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/62563 Expose a pair of functions to Python users: torch.autograd.graph.set_saved_tensors_default_hooks(pack, unpack) and torch.autograd.graph.reset_saved_tensors_default_hooks(). These functions control the hooks applied to saved tensors: all tensors saved in that context will be packed using the pack function, then unpacked accordingly when needed. Currently, this works by simply calling register_hooks (cf #60975) directly at the end of the constructor of a SavedVariable. This could be optimized further by not performing the copy before registering default hooks, but this would require a small refactor. Edit: the refactor is done in #61927. A current limitation is that if users create tensors in this context, they will not be able to register additional hooks on the saved tensor. For instance, to perform something like #28997, one could define a pack function that saves to disk whenever the tensor size is too big and returns a filename, then unpack simply reads the content of the file and outputs a tensor, e.g.: ``` def pack(x): name = os.path.join(tmp_dir, str(uuid.uuid4())) torch.save(x, name) return name def unpack(name): return torch.load(name) ``` Relanding previous PR: https://github.com/pytorch/pytorch/pull/61834 Original PR led to timeout error in: https://www.internalfb.com/mast/job/yuguo-release_canary_offline_training-inlinecvrp_a-canary_offline_train_28a7ecfc Now passing: https://www.internalfb.com/mast/job/quach-release_canary_offline_training-inlinecvrp_a-canary_offline_train_9bb57e98 The difference with the new version is we don't need to acquire the GIL when calling `PyDefaultSavedVariableHooks::get_hooks`. Test Plan: Imported from OSS Reviewed By: iramazanli Differential Revision: D30045405 Pulled By: Varal7 fbshipit-source-id: 7f6c07af3a56fe8835d5edcc815c15ea4fb4e332
82 lines
3.1 KiB
C++
82 lines
3.1 KiB
C++
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
|
|
|
|
#include <torch/csrc/THP.h>
|
|
|
|
namespace py = pybind11;
|
|
|
|
namespace torch { namespace autograd {
|
|
// Takes ownership of both Python callables: py::function::release() hands us
// the raw PyObject* without decrementing its refcount, so the hook objects
// stay alive until this instance manually decrefs them in its destructor.
PySavedVariableHooks::PySavedVariableHooks(py::function &pack_hook, py::function &unpack_hook) :
  // steals the reference (we will decref ourselves)
  pack_hook_(pack_hook.release().ptr()),
  unpack_hook_(unpack_hook.release().ptr()) {}
|
|
|
|
// Runs the user-provided pack hook on `tensor` and stores the (arbitrary
// Python object) result in data_ for later retrieval by call_unpack_hook.
void PySavedVariableHooks::call_pack_hook(at::Tensor &tensor) {
  // Calling back into Python requires the GIL.
  py::gil_scoped_acquire acquire;
  // Borrow our stored hook: refcount is bumped only for the call's duration.
  auto hook = py::reinterpret_borrow<py::function>(pack_hook_);
  // THPVariable_Wrap returns a new reference; steal it into a py::object so
  // it is decrefed automatically when this scope exits.
  auto wrapped = py::reinterpret_steal<py::object>(THPVariable_Wrap(tensor));
  // Keep a strong reference to the hook's result; data_ (like pack_hook_)
  // is manually decrefed when the saved variable is released.
  data_ = hook(wrapped).release().ptr();
}
|
|
|
|
// Runs the user-provided unpack hook on the payload stored by call_pack_hook
// and converts the result back to an at::Tensor; raises a TypeError-style
// check failure if the hook returns anything other than a Tensor.
at::Tensor PySavedVariableHooks::call_unpack_hook() {
  // Calling back into Python requires the GIL.
  py::gil_scoped_acquire acquire;
  auto hook = py::reinterpret_borrow<py::function>(unpack_hook_);
  // Borrow the packed payload previously saved in data_.
  auto packed = py::cast<py::object>(data_);
  py::object result = hook(packed);
  PyObject* raw = result.ptr();  // alive only as long as `result` is
  TORCH_CHECK_TYPE(
      THPVariable_Check(raw),
      "Output of saved tensor unpack_hook expected to be a Tensor but got result of type ",
      THPUtils_typename(raw));
  return THPVariable_Unpack(raw);
  // hook, packed and result are decrefed on exit;
  // unpack_hook_ is manually decrefed when the saved variable is released
}
|
|
|
|
PySavedVariableHooks::~PySavedVariableHooks() {
  if (!Py_IsInitialized()) {
    // Python is already dead: deliberately leak the wrapped Python objects,
    // since touching the C-API after finalization would be unsafe.
    return;
  }
  // Drop the references taken in the constructor and in call_pack_hook.
  py::gil_scoped_acquire gil;
  Py_XDECREF(pack_hook_);
  Py_XDECREF(unpack_hook_);
  Py_XDECREF(data_);
}
|
|
|
|
// Process-wide default hooks installed by set_hooks() (which steals the
// references) and released by reset_hooks(). nullptr means "not set".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* PyDefaultSavedVariableHooks::pack_hook_ = nullptr;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
PyObject* PyDefaultSavedVariableHooks::unpack_hook_ = nullptr;
|
|
|
|
// Installs a process-wide pair of default saved-tensor hooks, taking
// ownership of both callables. Fails if a pair is already installed.
void PyDefaultSavedVariableHooks::set_hooks(py::function &pack_hook, py::function &unpack_hook) {
  // Only one pair of default hooks may be active at a time.
  TORCH_CHECK(pack_hook_ == nullptr && unpack_hook_ == nullptr,
      "Setting default hooks but they have already been set. "
      "Hint: only one pair of hooks is allowed at a time.");
  // Steal the references; they are dropped in reset_hooks().
  pack_hook_ = pack_hook.release().ptr();
  unpack_hook_ = unpack_hook.release().ptr();
}
|
|
|
|
void PyDefaultSavedVariableHooks::reset_hooks() {
|
|
if (Py_IsInitialized()) {
|
|
py::gil_scoped_acquire gil;
|
|
Py_XDECREF(pack_hook_);
|
|
Py_XDECREF(unpack_hook_);
|
|
}
|
|
pack_hook_ = nullptr;
|
|
unpack_hook_ = nullptr;
|
|
}
|
|
|
|
std::unique_ptr<SavedVariableHooks> PyDefaultSavedVariableHooks::get_hooks() {
|
|
if (!pack_hook_ || !unpack_hook_) {
|
|
return nullptr;
|
|
}
|
|
py::gil_scoped_acquire gil;
|
|
py::function pack_hook = py::reinterpret_borrow<py::function>(pack_hook_);
|
|
py::function unpack_hook = py::reinterpret_borrow<py::function>(unpack_hook_);
|
|
return std::make_unique<PySavedVariableHooks>(pack_hook, unpack_hook);
|
|
}
|
|
|
|
}}
|