Add methods to access data and unpack_hook on SavedVariable (#164358)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164358
Approved by: https://github.com/albanD
Author: soulitzer
Date: 2025-10-01 11:07:27 -07:00
Committed by: PyTorch MergeBot
Parent: 39c340ec9e
Commit: bac0f289a3

3 changed files with 64 additions and 0 deletions
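
As a quick illustration (not part of the diff itself), the new read-only attributes can be inspected on the raw SavedVariable objects exposed through a node's _raw_saved_* accessors; the behavior shown below simply mirrors the test added in this commit:

import torch

def pack_hook(t):
    return t

def unpack_hook(t):
    return t

a = torch.tensor(2.0, requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    b = a**2  # pow saves its input through the registered hooks

sv = b.grad_fn._raw_saved_self
print(sv.data is a)                   # True: the packed object is returned as-is
print(sv.unpack_hook is unpack_hook)  # True: the registered unpack hook is exposed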

@@ -7952,6 +7952,35 @@ for shape in [(1,), ()]:
            for t in results:
                self.assertEqual(t.grad_fn._saved_scalars, scalars)

    def test_get_data_and_hooks_from_raw_saved_variable(self):
        def pack_hook(t):
            return t

        def unpack_hook(t):
            return t

        a = torch.tensor(2.0, requires_grad=True)
        with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
            b = a**2
        c = b.exp()
        d = c**2

        pow_sv = b.grad_fn._raw_saved_self
        exp_sv = c.grad_fn._raw_saved_result
        pow2_sv = d.grad_fn._raw_saved_self

        # Returns the packed object as-is
        self.assertTrue(pow_sv.data is a)
        self.assertTrue(pow_sv.unpack_hook is unpack_hook)

        # Returns the detached data when the output/leaf is saved
        self.assertFalse(exp_sv.data is c)
        self.assertIsNone(exp_sv.unpack_hook)

        # Returns the un-detached data when input is saved
        self.assertTrue(pow2_sv.data is c)
        self.assertIsNone(pow2_sv.unpack_hook)

    def test_cant_create_saved_tensors(self):
        with self.assertRaisesRegex(
            RuntimeError,

@@ -601,6 +601,33 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
            s.register_hooks(
                std::make_unique<torch::autograd::PySavedVariableHooks>(
                    pack_hook, unpack_hook));
          })
      .def_property_readonly(
          "data",
          [](const torch::autograd::SavedVariable& s) -> py::object {
            if (s.has_hooks()) {
              auto opt = s.retrieve_unpack_hook_data();
              TORCH_INTERNAL_ASSERT(opt.has_value());
              py::gil_scoped_acquire gil;
              const auto& [_unpack_fn, data_obj] = *opt;
              PyObject* raw = data_obj.ptr(getPyInterpreter());
              TORCH_INTERNAL_ASSERT(raw != nullptr);
              return py::reinterpret_borrow<py::object>(raw);
            } else {
              return py::cast(s.get_raw_data().value());
            }
          })
      .def_property_readonly(
          "unpack_hook",
          [](const torch::autograd::SavedVariable& s) -> py::object {
            auto opt = s.retrieve_unpack_hook_data();
            if (!opt.has_value()) {
              return py::none();
            }
            py::gil_scoped_acquire gil;
            const auto& [unpack_safe, _unused_data] = *opt;
            auto* unpack_ptr = unpack_safe.ptr(getPyInterpreter());
            return py::reinterpret_borrow<py::function>(unpack_ptr);
          });
  torch::autograd::profiler::python_tracer::init();

@@ -54,6 +54,14 @@ class TORCH_API SavedVariable {
    return (bool)hooks_;
  }

  std::optional<at::Tensor> get_raw_data() const {
    if (hooks_) {
      return std::nullopt;
    } else {
      return data_;
    }
  }

  // Used by compiled autograd
  std::optional<std::pair<c10::SafePyObject, c10::SafePyObject>>
  retrieve_unpack_hook_data() const {
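
For the no-hooks path that get_raw_data() backs, here is a minimal sketch (again not part of the diff) mirroring the assertions in the new test above:

import torch

a = torch.tensor(2.0, requires_grad=True)
b = a**2      # no saved_tensors_hooks are active anywhere here
c = b.exp()
d = c**2

# With no hooks registered, unpack_hook is None and data falls back to the stored tensor.
print(c.grad_fn._raw_saved_result.unpack_hook)  # None
print(c.grad_fn._raw_saved_result.data is c)    # False: a saved output is stored detached
print(d.grad_fn._raw_saved_self.data is c)      # True: a saved input is kept un-detached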