Allow GradientEdge as torch.autograd.backward outputs (#144744)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144744
Approved by: https://github.com/albanD
This commit is contained in:
soulitzer
2025-01-14 11:55:32 -05:00
committed by PyTorch MergeBot
parent 64829b356a
commit c000214826
2 changed files with 35 additions and 5 deletions

View File

@ -1151,7 +1151,6 @@ class TestAutograd(TestCase):
# Incorrect case: grad_outputs wrong size
out, tmp_edge = fn(x)
(tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
torch.autograd.grad(
tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
@ -1167,6 +1166,32 @@ class TestAutograd(TestCase):
grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
)
# Run with .backward() and compare with .grad()
out, tmp_edge = fn(x)
torch.autograd.backward(tmp_edge, retain_graph=True)
(x_grad_ref,) = torch.autograd.grad(tmp_edge, (x,), retain_graph=True)
self.assertEqual(x.grad, x_grad_ref)
# Pass a tuple of GradientEdges
x.grad = None
torch.autograd.backward((tmp_edge,), retain_graph=True)
self.assertEqual(x.grad, x_grad_ref)
# Mixing GradientEdge and Tensors
out1, tmp_edge1 = fn(x)
out2, tmp_edge2 = fn(x)
(x_grad_ref,) = torch.autograd.grad((tmp_edge1, out2), (x,), retain_graph=True)
x.grad = None
torch.autograd.backward((tmp_edge1, out2), retain_graph=True)
self.assertEqual(x.grad, x_grad_ref)
# .backward(): wrong shape
out, tmp_edge = fn(x)
with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
torch.autograd.backward(
tmp_edge, inputs=(x,), grad_tensors=torch.tensor([1.0, 2.0, 3.0, 4.0])
)
def test_grad_nonleaf(self):
x_init = torch.randn(2, 2, requires_grad=True)
x = x_init