Fix module full backward pre hooks when the input doesn't require grad but gradients are changed by the user (#116454)

As per title: when no module input requires grad, `BackwardHook` calls the user hooks directly and clears `self.grad_outputs`, which dropped any changes a full backward pre hook had made to the gradients. Keeping a local reference to the (possibly modified) grad outputs lets them still be returned to the autograd engine.

FYI @vkuzo
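
For context, a minimal repro sketch of the scenario this change covers: a full backward pre hook rewrites grad_output on a module whose input does not require grad. The hook name zero_grad_output and the printed expectations are illustrative, not part of the PR.

import torch
import torch.nn as nn

def zero_grad_output(_unused_module, grad_output):
    # Backward pre hook that rewrites the incoming gradient to zeros.
    return (grad_output[0] * 0,)

a = torch.ones(2, requires_grad=False)  # module input does NOT require grad
model = nn.Linear(2, 2)
model.register_full_backward_pre_hook(zero_grad_output)

out = model(a)
out.sum().backward()

# With this fix the rewritten grad_output still reaches the parameter gradients,
# so the weight gradient is zero even though no module input required grad.
print(model.weight.grad)  # expected: 2x2 tensor of zeros
print(a.grad)             # expected: None, since `a` does not require grad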
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116454
Approved by: https://github.com/mikaylagawarecki
albanD
2023-12-28 18:32:50 +00:00
committed by PyTorch MergeBot
parent fb91acd33b
commit f10c3f4184
2 changed files with 19 additions and 10 deletions


@@ -281,17 +281,23 @@ class TestModuleHooks(TestCase):
         self.assertEqual(fired_hooks, expected + expected)
 
         # Backward pre hook can affect subsequent gradient computation
-        a = torch.ones(2, requires_grad=True)
-        model = nn.Linear(2, 2)
+        for rg in [True, False]:
+            a = torch.ones(2, requires_grad=rg)
+            model = nn.Linear(2, 2)
 
-        def fn(_unused_module, grad_output):
-            return (grad_output[0] * 0,)
+            def fn(_unused_module, grad_output):
+                return (grad_output[0] * 0,)
 
-        model.register_full_backward_pre_hook(fn)
+            model.register_full_backward_pre_hook(fn)
 
-        out = model(a)
-        out.sum().backward()
-        self.assertEqual(a.grad, torch.zeros_like(a))
+            out = model(a)
+            out.sum().backward()
+            self.assertEqual(model.weight.grad, torch.zeros(2, 2))
+            if rg:
+                self.assertEqual(a.grad, torch.zeros_like(a))
+            else:
+                self.assertIsNone(a.grad)
 
     @parametrize_test("named_tuple", (True, False))
     def test_mixed_hooks(self, named_tuple):


@@ -218,6 +218,9 @@ class BackwardHook:
                                        f"got {actual_len}, but expected {expected_len}")
                 self.grad_outputs = hook_grad_outputs
 
+            # We need to be able to clear self.grad_outputs but also return it
+            local_grad_outputs = self.grad_outputs
+
             # Special case if no input required gradients, this hook should call the user
             # hook directly
             if self.input_tensors_index is None:
@@ -229,9 +232,9 @@ class BackwardHook:
                                            "gradient should always return None or None for all gradients.")
                 self.grad_outputs = None
 
-            if self.grad_outputs is not None:
+            if local_grad_outputs is not None:
                 assert self.output_tensors_index is not None  # mypy
-                return tuple(self.grad_outputs[i] for i in self.output_tensors_index)
+                return tuple(local_grad_outputs[i] for i in self.output_tensors_index)
 
         grad_fn.register_hook(hook)
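
As a side note, here is a minimal sketch, in plain Python rather than PyTorch internals, of the pattern the hooks.py change relies on: snapshot the attribute into a local before the no-input-requires-grad branch clears it, so the possibly user-modified value can still be returned. The HookState class and its fields are invented for illustration.

from typing import Optional, Tuple

class HookState:
    def __init__(self) -> None:
        # Stands in for self.grad_outputs after user pre hooks may have rewritten it.
        self.grad_outputs: Optional[Tuple[int, ...]] = (0, 0, 0)

    def finish(self, any_input_requires_grad: bool) -> Optional[Tuple[int, ...]]:
        # Keep a reference before the attribute is cleared below.
        local_grad_outputs = self.grad_outputs

        if not any_input_requires_grad:
            # Mirrors the special case: the user hooks are called directly and
            # the attribute is cleared afterwards.
            self.grad_outputs = None

        # Returning from the local reference preserves the user's modifications
        # even though self.grad_outputs may have just been set to None.
        return local_grad_outputs

state = HookState()
assert state.finish(any_input_requires_grad=False) == (0, 0, 0)
assert state.grad_outputs is None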