Allow GradientEdge as torch.autograd.backward outputs (#144744)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144744
Approved by: https://github.com/albanD
soulitzer
2025-01-14 11:55:32 -05:00
committed by PyTorch MergeBot
parent 64829b356a
commit c000214826
2 changed files with 35 additions and 5 deletions
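
In short: torch.autograd.backward() now accepts GradientEdge objects (a single edge, a sequence of edges, or a mix of edges and Tensors) as its tensors argument, matching what torch.autograd.grad() already allowed for its outputs. A minimal sketch of the new call pattern, mirroring the added test; torch.autograd.graph.get_gradient_edge is assumed here to build the edge, since the fn helper used by the test is not part of this diff:

    import torch
    from torch.autograd.graph import get_gradient_edge

    x = torch.randn(2, 2, requires_grad=True)
    y = x.sin()                        # non-leaf intermediate
    edge = get_gradient_edge(y)        # GradientEdge pointing at y in the graph

    # Previously `tensors` had to be Tensors; an edge (or a tuple mixing edges
    # and Tensors) is now accepted, and gradients accumulate into x.grad.
    torch.autograd.backward(edge, grad_tensors=torch.ones_like(y))
    print(torch.allclose(x.grad, x.cos()))   # d(sin x)/dx = cos x, so True

As before, the grad_tensors shape is validated against the edge's output; the new "Mismatch in shape" test below exercises exactly that.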

test/test_autograd.py

@@ -1151,7 +1151,6 @@ class TestAutograd(TestCase):
        # Incorrect case: grad_outputs wrong size
        out, tmp_edge = fn(x)
        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
        with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
            torch.autograd.grad(
                tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
@@ -1167,6 +1166,32 @@
                grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
            )

        # Run with .backward() and compare with .grad()
        out, tmp_edge = fn(x)
        torch.autograd.backward(tmp_edge, retain_graph=True)
        (x_grad_ref,) = torch.autograd.grad(tmp_edge, (x,), retain_graph=True)
        self.assertEqual(x.grad, x_grad_ref)

        # Pass a tuple of GradientEdges
        x.grad = None
        torch.autograd.backward((tmp_edge,), retain_graph=True)
        self.assertEqual(x.grad, x_grad_ref)

        # Mixing GradientEdge and Tensors
        out1, tmp_edge1 = fn(x)
        out2, tmp_edge2 = fn(x)
        (x_grad_ref,) = torch.autograd.grad((tmp_edge1, out2), (x,), retain_graph=True)
        x.grad = None
        torch.autograd.backward((tmp_edge1, out2), retain_graph=True)
        self.assertEqual(x.grad, x_grad_ref)

        # .backward(): wrong shape
        out, tmp_edge = fn(x)
        with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
            torch.autograd.backward(
                tmp_edge, inputs=(x,), grad_tensors=torch.tensor([1.0, 2.0, 3.0, 4.0])
            )

    def test_grad_nonleaf(self):
        x_init = torch.randn(2, 2, requires_grad=True)
        x = x_init

torch/autograd/__init__.py

@@ -240,7 +240,7 @@ def _tensor_or_tensors_to_tuple(
def backward(
-    tensors: _TensorOrTensors,
+    tensors: _TensorOrTensorsOrGradEdge,
    grad_tensors: Optional[_TensorOrTensors] = None,
    retain_graph: Optional[bool] = None,
    create_graph: bool = False,
@@ -284,8 +284,8 @@ def backward(
    See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

    Args:
-        tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
-            computed.
+        tensors (Sequence[Tensor] or Tensor or Sequence[GradientEdge] or GradientEdge): Tensors of which
+            the derivative will be computed.
        grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
            the Jacobian-vector product, usually gradients w.r.t. each element of
            corresponding tensors. None values can be specified for scalar Tensors or
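
To make the updated Args entry concrete, here is a small sketch (again assuming torch.autograd.graph.get_gradient_edge, which this diff does not show) checking that backward() and grad() agree on a mixed sequence of a GradientEdge and a Tensor, the same way the new test does:

    import torch
    from torch.autograd.graph import get_gradient_edge

    x = torch.randn(3, requires_grad=True)
    edge = get_gradient_edge(x.sin())   # GradientEdge output
    out = x.exp().sum()                 # ordinary scalar Tensor output
    v = torch.ones(3)                   # the "vector" for the non-scalar edge

    # grad() returns the gradient; backward() accumulates the same value into x.grad.
    (ref,) = torch.autograd.grad((edge, out), (x,), grad_outputs=(v, None), retain_graph=True)
    torch.autograd.backward((edge, out), grad_tensors=(v, None))
    print(torch.allclose(x.grad, ref))  # True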
@@ -327,7 +327,12 @@ def backward(
    if inputs is not None and len(inputs) == 0:
        raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")

-    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
+    if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge):
+        tensors = cast(
+            Union[Tuple[torch.Tensor], Tuple[graph.GradientEdge]], (tensors,)
+        )
+    else:
+        tensors = tuple(tensors)
    inputs = (
        (inputs,)
        if isinstance(inputs, (torch.Tensor, graph.GradientEdge))