Fix module pre bw hooks when input doesn't req grad but gradients are changed by the user (#116454)

As per title. FYI @vkuzo

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116454
Approved by: https://github.com/mikaylagawarecki
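A minimal, self-contained sketch of the scenario this commit fixes, adapted from the new test below (assumes a PyTorch version with full backward pre-hooks, i.e. 2.0+). Before the fix, the gradients returned by a backward pre-hook could be silently dropped whenever the module input did not require grad:

import torch
import torch.nn as nn

model = nn.Linear(2, 2)

def zero_grad_output(module, grad_output):
    # Full backward pre-hook: replace the incoming gradient with zeros.
    return (grad_output[0] * 0,)

model.register_full_backward_pre_hook(zero_grad_output)

a = torch.ones(2, requires_grad=False)  # the input does NOT require grad
out = model(a)
out.sum().backward()

# With the fix, the zeroed gradient reaches the parameters:
assert torch.equal(model.weight.grad, torch.zeros(2, 2))
assert a.grad is None  # the input never required grad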
@@ -281,17 +281,23 @@ class TestModuleHooks(TestCase):
         self.assertEqual(fired_hooks, expected + expected)
 
         # Backward pre hook can affect subsequent gradient computation
-        a = torch.ones(2, requires_grad=True)
-        model = nn.Linear(2, 2)
+        for rg in [True, False]:
+            a = torch.ones(2, requires_grad=rg)
+            model = nn.Linear(2, 2)
 
-        def fn(_unused_module, grad_output):
-            return (grad_output[0] * 0,)
+            def fn(_unused_module, grad_output):
+                return (grad_output[0] * 0,)
 
-        model.register_full_backward_pre_hook(fn)
+            model.register_full_backward_pre_hook(fn)
 
-        out = model(a)
-        out.sum().backward()
-        self.assertEqual(a.grad, torch.zeros_like(a))
+            out = model(a)
+            out.sum().backward()
+            self.assertEqual(model.weight.grad, torch.zeros(2, 2))
+            if rg:
+                self.assertEqual(a.grad, torch.zeros_like(a))
+            else:
+                self.assertIsNone(a.grad)
 
     @parametrize_test("named_tuple", (True, False))
     def test_mixed_hooks(self, named_tuple):
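The fix itself is in BackwardHook, the helper class that implements full backward hooks. In the special case where no module input requires grad, the hook calls the user hooks directly and then clears self.grad_outputs; the final return afterwards read the already-cleared attribute, so gradients modified by a backward pre-hook were dropped. The change below snapshots the value before that branch can clear it.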
@@ -218,6 +218,9 @@ class BackwardHook:
                                                f"got {actual_len}, but expected {expected_len}")
                     self.grad_outputs = hook_grad_outputs
 
+            # We need to be able to clear self.grad_outputs but also return it
+            local_grad_outputs = self.grad_outputs
+
             # Special case if no input required gradients, this hook should call the user
             # hook directly
             if self.input_tensors_index is None:
@@ -229,9 +232,9 @@ class BackwardHook:
                                                "gradient should always return None or None for all gradients.")
                 self.grad_outputs = None
 
-            if self.grad_outputs is not None:
+            if local_grad_outputs is not None:
                 assert self.output_tensors_index is not None  # mypy
-                return tuple(self.grad_outputs[i] for i in self.output_tensors_index)
+                return tuple(local_grad_outputs[i] for i in self.output_tensors_index)
 
         grad_fn.register_hook(hook)
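A stripped-down sketch of the snapshot-then-clear pattern this change introduces (the class and names here are illustrative, not the real BackwardHook internals):

class StateHolder:
    # Illustrative stand-in for BackwardHook's clear-then-return flow.
    def __init__(self, grad_outputs, input_requires_grad):
        self.grad_outputs = grad_outputs
        self.input_requires_grad = input_requires_grad

    def hook(self):
        # Snapshot first: the branch below may clear self.grad_outputs,
        # but the cached value must still be returned afterwards.
        local_grad_outputs = self.grad_outputs
        if not self.input_requires_grad:
            self.grad_outputs = None  # cleared, as in the special case
        if local_grad_outputs is not None:
            return tuple(local_grad_outputs)
        return None

With the snapshot, StateHolder([1.0, 2.0], input_requires_grad=False).hook() returns (1.0, 2.0); reading self.grad_outputs at the end instead would return None.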