Allow GradientEdge as torch.autograd.backward outputs (#144744)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144744
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: 64829b356a
Commit: c000214826
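Context for the change: `torch.autograd.grad` already accepted `GradientEdge` objects (from `torch.autograd.graph.get_gradient_edge`) as its `outputs` argument, as the pre-existing test lines below show; this PR extends the same support to `torch.autograd.backward`. A minimal sketch of the usage this enables, with names chosen here for illustration (not taken from the PR):

```python
import torch
from torch.autograd.graph import get_gradient_edge

x = torch.randn(2, 2, requires_grad=True)
tmp = x.sin()                  # intermediate tensor in the graph
edge = get_gradient_edge(tmp)  # GradientEdge pointing at tmp's position in the graph

# New in this PR: backward() accepts a GradientEdge where it previously
# required a Tensor. The seed gradient must match tmp's shape.
torch.autograd.backward(edge, grad_tensors=torch.ones_like(tmp))
print(x.grad)                  # gradient of tmp wrt x under a ones seed, i.e. cos(x)
```

Starting backpropagation from the edge rather than from `tmp` itself means autograd enters the graph at that node without the caller holding on to the tensor, which is the same pattern the updated test exercises through its `fn` helper.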
@@ -1151,7 +1151,6 @@ class TestAutograd(TestCase):
         # Incorrect case: grad_outputs wrong size
         out, tmp_edge = fn(x)
         (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
         with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
             torch.autograd.grad(
                 tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
@@ -1167,6 +1166,32 @@ class TestAutograd(TestCase):
                 grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
             )
 
+        # Run with .backward() and compare with .grad()
+        out, tmp_edge = fn(x)
+        torch.autograd.backward(tmp_edge, retain_graph=True)
+        (x_grad_ref,) = torch.autograd.grad(tmp_edge, (x,), retain_graph=True)
+        self.assertEqual(x.grad, x_grad_ref)
+
+        # Pass a tuple of GradientEdges
+        x.grad = None
+        torch.autograd.backward((tmp_edge,), retain_graph=True)
+        self.assertEqual(x.grad, x_grad_ref)
+
+        # Mixing GradientEdge and Tensors
+        out1, tmp_edge1 = fn(x)
+        out2, tmp_edge2 = fn(x)
+        (x_grad_ref,) = torch.autograd.grad((tmp_edge1, out2), (x,), retain_graph=True)
+        x.grad = None
+        torch.autograd.backward((tmp_edge1, out2), retain_graph=True)
+        self.assertEqual(x.grad, x_grad_ref)
+
+        # .backward(): wrong shape
+        out, tmp_edge = fn(x)
+        with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
+            torch.autograd.backward(
+                tmp_edge, inputs=(x,), grad_tensors=torch.tensor([1.0, 2.0, 3.0, 4.0])
+            )
+
     def test_grad_nonleaf(self):
         x_init = torch.randn(2, 2, requires_grad=True)
         x = x_init
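For readers of the hunk above: `fn` is a helper defined earlier in the test file and not included in this diff. A hypothetical reconstruction, only to make the snippet self-contained; the real helper's operations may differ. The key points are that it returns both a regular output tensor and a `GradientEdge` for an intermediate node, and that the intermediate is reduced to a scalar so the no-`grad_tensors` call to `backward` can seed an implicit gradient:

```python
import torch
from torch.autograd.graph import get_gradient_edge

def fn(x):
    # Build a small graph, capture a GradientEdge mid-graph, keep going.
    tmp = x.sin().sum()  # scalar intermediate => implicit grad seeding works
    out = tmp.exp()      # final output built on top of tmp
    return out, get_gradient_edge(tmp)

x = torch.tensor([1.0, 2.0], requires_grad=True)
out, tmp_edge = fn(x)
torch.autograd.backward(tmp_edge, retain_graph=True)  # as in the test above
print(x.grad)  # cos(x), the gradient of tmp wrt x
```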