Allow GradientEdge as torch.autograd.backward outputs (#144744)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144744
Approved by: https://github.com/albanD
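In practice, this change lets `.backward()` be driven from an intermediate node of the autograd graph without holding on to the intermediate tensor itself. A minimal sketch of the new behavior (the toy graph and variable names here are illustrative, not from the PR; assumes a PyTorch build containing this change):

    import torch
    from torch.autograd.graph import get_gradient_edge

    x = torch.randn(4, requires_grad=True)
    tmp = x.mul(2)                 # intermediate (non-leaf) tensor
    edge = get_gradient_edge(tmp)  # handle to tmp's position in the graph
    del tmp                        # the tensor itself is no longer needed

    # Before this PR only Tensors (or sequences of Tensors) were accepted
    # as the first argument; now a GradientEdge works too. grad_tensors
    # must match the shape of the edge's output, here (4,).
    torch.autograd.backward(edge, grad_tensors=torch.ones(4))
    print(x.grad)  # tensor([2., 2., 2., 2.])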
Committed by: PyTorch MergeBot
Parent: 64829b356a
Commit: c000214826
test/test_autograd.py
@@ -1151,7 +1151,6 @@ class TestAutograd(TestCase):
         # Incorrect case: grad_outputs wrong size
         out, tmp_edge = fn(x)
         (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
         with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
             torch.autograd.grad(
                 tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
@@ -1167,6 +1166,32 @@ class TestAutograd(TestCase):
                 grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
             )
+
+        # Run with .backward() and compare with .grad()
+        out, tmp_edge = fn(x)
+        torch.autograd.backward(tmp_edge, retain_graph=True)
+        (x_grad_ref,) = torch.autograd.grad(tmp_edge, (x,), retain_graph=True)
+        self.assertEqual(x.grad, x_grad_ref)
+
+        # Pass a tuple of GradientEdges
+        x.grad = None
+        torch.autograd.backward((tmp_edge,), retain_graph=True)
+        self.assertEqual(x.grad, x_grad_ref)
+
+        # Mixing GradientEdge and Tensors
+        out1, tmp_edge1 = fn(x)
+        out2, tmp_edge2 = fn(x)
+        (x_grad_ref,) = torch.autograd.grad((tmp_edge1, out2), (x,), retain_graph=True)
+        x.grad = None
+        torch.autograd.backward((tmp_edge1, out2), retain_graph=True)
+        self.assertEqual(x.grad, x_grad_ref)
+
+        # .backward(): wrong shape
+        out, tmp_edge = fn(x)
+        with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
+            torch.autograd.backward(
+                tmp_edge, inputs=(x,), grad_tensors=torch.tensor([1.0, 2.0, 3.0, 4.0])
+            )

     def test_grad_nonleaf(self):
         x_init = torch.randn(2, 2, requires_grad=True)
         x = x_init
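The tests above rely on a helper `fn` that returns both a regular output and a `GradientEdge` into the graph; that helper is defined outside this hunk. A hypothetical standalone equivalent, just to make the pattern runnable (the body of `fn` and the explicit grad values are assumptions, not from the PR):

    import torch
    from torch.autograd.graph import get_gradient_edge

    def fn(x):
        # Hypothetical stand-in for the test helper: return a downstream
        # output plus a GradientEdge pointing at an intermediate tensor.
        tmp = x.sin()
        return tmp.sum(), get_gradient_edge(tmp)

    x = torch.randn(4, requires_grad=True)
    out, tmp_edge = fn(x)

    # Same equivalence the test asserts: .backward() on the edge populates
    # x.grad with what .grad() computes for that edge. Explicit grad values
    # are passed since tmp has 4 elements in this sketch.
    torch.autograd.backward(tmp_edge, grad_tensors=torch.ones(4), retain_graph=True)
    (x_grad_ref,) = torch.autograd.grad(
        tmp_edge, (x,), grad_outputs=torch.ones(4), retain_graph=True
    )
    assert torch.allclose(x.grad, x_grad_ref)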
torch/autograd/__init__.py
@@ -240,7 +240,7 @@ def _tensor_or_tensors_to_tuple(


 def backward(
-    tensors: _TensorOrTensors,
+    tensors: _TensorOrTensorsOrGradEdge,
     grad_tensors: Optional[_TensorOrTensors] = None,
     retain_graph: Optional[bool] = None,
     create_graph: bool = False,
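The new annotation reuses the `_TensorOrTensorsOrGradEdge` alias that `torch.autograd.grad` already accepts elsewhere in the same module; its rough shape (reconstructed from memory, not shown in this diff) is:

    from typing import Sequence, Union

    import torch
    from torch.autograd import graph

    # Approximate shape of the existing alias used by the new signature:
    _TensorOrTensorsOrGradEdge = Union[
        torch.Tensor,
        Sequence[torch.Tensor],
        graph.GradientEdge,
        Sequence[graph.GradientEdge],
    ]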
@@ -284,8 +284,8 @@ def backward(
     See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.

     Args:
-        tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
-            computed.
+        tensors (Sequence[Tensor] or Tensor or Sequence[GradientEdge] or GradientEdge): Tensors of which
+            the derivative will be computed.
         grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
             the Jacobian-vector product, usually gradients w.r.t. each element of
             corresponding tensors. None values can be specified for scalar Tensors or
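As the updated docstring states, `tensors` may now freely mix the two kinds, with one grad "vector" per element. A short sketch (toy graph is illustrative):

    import torch
    from torch.autograd.graph import get_gradient_edge

    x = torch.randn(3, requires_grad=True)
    a = x.exp()
    b = x * 3

    # Mix a plain Tensor and a GradientEdge in a single backward call;
    # grad_tensors pairs up elementwise with `tensors`.
    torch.autograd.backward(
        (a.sum(), get_gradient_edge(b)),
        grad_tensors=(torch.tensor(1.0), torch.ones(3)),
    )
    print(x.grad)  # exp(x) + 3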
@@ -327,7 +327,12 @@ def backward(
     if inputs is not None and len(inputs) == 0:
         raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")

-    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
+    if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge):
+        tensors = cast(
+            Union[Tuple[torch.Tensor], Tuple[graph.GradientEdge]], (tensors,)
+        )
+    else:
+        tensors = tuple(tensors)
     inputs = (
         (inputs,)
         if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
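The replaced one-liner only special-cased `torch.Tensor`; the new branch also treats a bare `GradientEdge` (or anything tensor-like) as a singleton before tupling. A self-contained sketch of the same normalization idiom (the `_normalize` name is local to this example):

    import torch
    from torch.autograd import graph
    from torch.overrides import is_tensor_like

    def _normalize(tensors):
        # Wrap a single Tensor-like or GradientEdge into a 1-tuple; pass
        # sequences through as tuples, mirroring the logic in the hunk above.
        if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge):
            return (tensors,)
        return tuple(tensors)

    x = torch.randn(2, requires_grad=True)
    single = _normalize(x)
    assert isinstance(single, tuple) and single[0] is x   # Tensor -> 1-tuple
    assert len(_normalize([x, x.clone()])) == 2           # sequence -> tuple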