Properly move retains_grad hook on in-place over view for base (#117552)

Fixes https://github.com/pytorch/pytorch/issues/117366
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117552
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer
2024-01-24 16:44:19 -05:00
committed by PyTorch MergeBot
parent 9c1348feb3
commit 5b819d9ef0
2 changed files with 23 additions and 0 deletions

View File

@@ -239,6 +239,11 @@ void rebase_history(const Variable& self, Edge gradient_edge) {
at::TensorGeometry(self),
view_info.view_fn_,
std::move(gradient_edge.function));
if (self.requires_grad()) {
// If self did not previously require grad, there are no hooks to move
torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
view_info.base_, view_info.base_.grad_fn(), copy_slices);
}
set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
self.grad_fn(); // trigger an update to the view's grad_fn
return;