diff --git a/test/test_autograd.py b/test/test_autograd.py
index ee9605fc787f..2a00b620d78e 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1151,7 +1151,6 @@ class TestAutograd(TestCase):
 
         # Incorrect case: grad_outputs wrong size
         out, tmp_edge = fn(x)
-        (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,))
         with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
             torch.autograd.grad(
                 tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0])
@@ -1167,6 +1166,32 @@ class TestAutograd(TestCase):
                 grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64),
             )
 
+        # Run with .backward() and compare with .grad()
+        out, tmp_edge = fn(x)
+        torch.autograd.backward(tmp_edge, retain_graph=True)
+        (x_grad_ref,) = torch.autograd.grad(tmp_edge, (x,), retain_graph=True)
+        self.assertEqual(x.grad, x_grad_ref)
+
+        # Pass a tuple of GradientEdges
+        x.grad = None
+        torch.autograd.backward((tmp_edge,), retain_graph=True)
+        self.assertEqual(x.grad, x_grad_ref)
+
+        # Mixing GradientEdge and Tensors
+        out1, tmp_edge1 = fn(x)
+        out2, tmp_edge2 = fn(x)
+        (x_grad_ref,) = torch.autograd.grad((tmp_edge1, out2), (x,), retain_graph=True)
+        x.grad = None
+        torch.autograd.backward((tmp_edge1, out2), retain_graph=True)
+        self.assertEqual(x.grad, x_grad_ref)
+
+        # .backward(): wrong shape
+        out, tmp_edge = fn(x)
+        with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"):
+            torch.autograd.backward(
+                tmp_edge, inputs=(x,), grad_tensors=torch.tensor([1.0, 2.0, 3.0, 4.0])
+            )
+
     def test_grad_nonleaf(self):
         x_init = torch.randn(2, 2, requires_grad=True)
         x = x_init
diff --git a/torch/autograd/__init__.py b/torch/autograd/__init__.py
index f51456ef15ce..46b047db40c1 100644
--- a/torch/autograd/__init__.py
+++ b/torch/autograd/__init__.py
@@ -240,7 +240,7 @@ def _tensor_or_tensors_to_tuple(
 
 
 def backward(
-    tensors: _TensorOrTensors,
+    tensors: _TensorOrTensorsOrGradEdge,
     grad_tensors: Optional[_TensorOrTensors] = None,
     retain_graph: Optional[bool] = None,
     create_graph: bool = False,
@@ -284,8 +284,8 @@ def backward(
     See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details.
 
     Args:
-        tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be
-            computed.
+        tensors (Sequence[Tensor] or Tensor or Sequence[GradientEdge] or GradientEdge): Tensors of which
+            the derivative will be computed.
         grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in
             the Jacobian-vector product, usually gradients w.r.t. each element of
             corresponding tensors. None values can be specified for scalar Tensors or
@@ -327,7 +327,12 @@ def backward(
     if inputs is not None and len(inputs) == 0:
         raise RuntimeError("`inputs` argument to `backward()` cannot be empty.")
 
-    tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors)
+    if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge):
+        tensors = cast(
+            Union[Tuple[torch.Tensor], Tuple[graph.GradientEdge]], (tensors,)
+        )
+    else:
+        tensors = tuple(tensors)
     inputs = (
         (inputs,)
         if isinstance(inputs, (torch.Tensor, graph.GradientEdge))
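
For reference, a minimal usage sketch of what this diff enables; it is not part of the patch. It assumes the existing public helper torch.autograd.graph.get_gradient_edge (the same mechanism the test's fn() helper uses to produce tmp_edge), and the variable names are illustrative.

    import torch
    from torch.autograd.graph import get_gradient_edge

    x = torch.randn(4, requires_grad=True)
    tmp = x.sin().sum()            # scalar intermediate value
    edge = get_gradient_edge(tmp)  # GradientEdge referencing tmp's grad_fn output

    # Previously only torch.autograd.grad() accepted a GradientEdge root;
    # with this change backward() does too, accumulating into leaf .grad fields.
    torch.autograd.backward(edge)
    print(x.grad)  # cos(x)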