mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Add methods to access data and unpack_hook on SavedVariable (#164358)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164358 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
39c340ec9e
commit
bac0f289a3
@ -7952,6 +7952,35 @@ for shape in [(1,), ()]:
|
||||
for t in results:
|
||||
self.assertEqual(t.grad_fn._saved_scalars, scalars)
|
||||
|
||||
def test_get_data_and_hooks_from_raw_saved_variable(self):
|
||||
def pack_hook(t):
|
||||
return t
|
||||
|
||||
def unpack_hook(t):
|
||||
return t
|
||||
|
||||
a = torch.tensor(2.0, requires_grad=True)
|
||||
|
||||
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
|
||||
b = a**2
|
||||
|
||||
c = b.exp()
|
||||
d = c**2
|
||||
|
||||
pow_sv = b.grad_fn._raw_saved_self
|
||||
exp_sv = c.grad_fn._raw_saved_result
|
||||
pow2_sv = d.grad_fn._raw_saved_self
|
||||
|
||||
# Returns the packed object as-is
|
||||
self.assertTrue(pow_sv.data is a)
|
||||
self.assertTrue(pow_sv.unpack_hook is unpack_hook)
|
||||
# Returns the detached data when the output/leaf is saved
|
||||
self.assertFalse(exp_sv.data is c)
|
||||
self.assertIsNone(exp_sv.unpack_hook)
|
||||
# Returns the un-detached data when input is saved
|
||||
self.assertTrue(pow2_sv.data is c)
|
||||
self.assertIsNone(pow2_sv.unpack_hook)
|
||||
|
||||
def test_cant_create_saved_tensors(self):
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
|
@ -601,6 +601,33 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
|
||||
s.register_hooks(
|
||||
std::make_unique<torch::autograd::PySavedVariableHooks>(
|
||||
pack_hook, unpack_hook));
|
||||
})
|
||||
.def_property_readonly(
|
||||
"data",
|
||||
[](const torch::autograd::SavedVariable& s) -> py::object {
|
||||
if (s.has_hooks()) {
|
||||
auto opt = s.retrieve_unpack_hook_data();
|
||||
TORCH_INTERNAL_ASSERT(opt.has_value());
|
||||
py::gil_scoped_acquire gil;
|
||||
const auto& [_unpack_fn, data_obj] = *opt;
|
||||
PyObject* raw = data_obj.ptr(getPyInterpreter());
|
||||
TORCH_INTERNAL_ASSERT(raw != nullptr);
|
||||
return py::reinterpret_borrow<py::object>(raw);
|
||||
} else {
|
||||
return py::cast(s.get_raw_data().value());
|
||||
}
|
||||
})
|
||||
.def_property_readonly(
|
||||
"unpack_hook",
|
||||
[](const torch::autograd::SavedVariable& s) -> py::object {
|
||||
auto opt = s.retrieve_unpack_hook_data();
|
||||
if (!opt.has_value()) {
|
||||
return py::none();
|
||||
}
|
||||
py::gil_scoped_acquire gil;
|
||||
const auto& [unpack_safe, _unused_data] = *opt;
|
||||
auto* unpack_ptr = unpack_safe.ptr(getPyInterpreter());
|
||||
return py::reinterpret_borrow<py::function>(unpack_ptr);
|
||||
});
|
||||
|
||||
torch::autograd::profiler::python_tracer::init();
|
||||
|
@ -54,6 +54,14 @@ class TORCH_API SavedVariable {
|
||||
return (bool)hooks_;
|
||||
}
|
||||
|
||||
std::optional<at::Tensor> get_raw_data() const {
|
||||
if (hooks_) {
|
||||
return std::nullopt;
|
||||
} else {
|
||||
return data_;
|
||||
}
|
||||
}
|
||||
|
||||
// Used by compiled autograd
|
||||
std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
|
||||
retrieve_unpack_hook_data() const {
|
||||
|
Reference in New Issue
Block a user