Fix module backward pre-hooks to actually update gradient (#97983)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97983
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer
2023-03-30 12:30:19 -04:00
committed by PyTorch MergeBot
parent 06d677f41d
commit ee1c539ecf
2 changed files with 18 additions and 0 deletions

View File

@ -223,6 +223,11 @@ class BackwardHook:
raise RuntimeError("Backward hook for Modules where no input requires "
"gradient should always return None or None for all gradients.")
self.grad_outputs = None
if self.grad_outputs is not None:
assert self.output_tensors_index is not None # mypy
return tuple(self.grad_outputs[i] for i in self.output_tensors_index)
grad_fn.register_hook(hook)
is_tuple = True