mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	Allow GradientEdge as torch.autograd.backward outputs (#144744)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144744 Approved by: https://github.com/albanD
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							64829b356a
						
					
				
				
					commit
					c000214826
				
			| @ -1151,7 +1151,6 @@ class TestAutograd(TestCase): | ||||
|  | ||||
|         # Incorrect case: grad_outputs wrong size | ||||
|         out, tmp_edge = fn(x) | ||||
|         (tmp_grad,) = torch.autograd.grad(out, (tmp_edge,)) | ||||
|         with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"): | ||||
|             torch.autograd.grad( | ||||
|                 tmp_edge, (x,), grad_outputs=torch.tensor([1.0, 2.0, 3.0, 4.0]) | ||||
| @ -1167,6 +1166,32 @@ class TestAutograd(TestCase): | ||||
|                 grad_outputs=torch.rand_like(tmp_grad, dtype=torch.complex64), | ||||
|             ) | ||||
|  | ||||
|         # Run with .backward() and compare with .grad() | ||||
|         out, tmp_edge = fn(x) | ||||
|         torch.autograd.backward(tmp_edge, retain_graph=True) | ||||
|         (x_grad_ref,) = torch.autograd.grad(tmp_edge, (x,), retain_graph=True) | ||||
|         self.assertEqual(x.grad, x_grad_ref) | ||||
|  | ||||
|         # Pass a tuple of GradientEdges | ||||
|         x.grad = None | ||||
|         torch.autograd.backward((tmp_edge,), retain_graph=True) | ||||
|         self.assertEqual(x.grad, x_grad_ref) | ||||
|  | ||||
|         # Mixing GradientEdge and Tensors | ||||
|         out1, tmp_edge1 = fn(x) | ||||
|         out2, tmp_edge2 = fn(x) | ||||
|         (x_grad_ref,) = torch.autograd.grad((tmp_edge1, out2), (x,), retain_graph=True) | ||||
|         x.grad = None | ||||
|         torch.autograd.backward((tmp_edge1, out2), retain_graph=True) | ||||
|         self.assertEqual(x.grad, x_grad_ref) | ||||
|  | ||||
|         # .backward(): wrong shape | ||||
|         out, tmp_edge = fn(x) | ||||
|         with self.assertRaisesRegex(RuntimeError, "Mismatch in shape"): | ||||
|             torch.autograd.backward( | ||||
|                 tmp_edge, inputs=(x,), grad_tensors=torch.tensor([1.0, 2.0, 3.0, 4.0]) | ||||
|             ) | ||||
|  | ||||
|     def test_grad_nonleaf(self): | ||||
|         x_init = torch.randn(2, 2, requires_grad=True) | ||||
|         x = x_init | ||||
|  | ||||
| @ -240,7 +240,7 @@ def _tensor_or_tensors_to_tuple( | ||||
|  | ||||
|  | ||||
| def backward( | ||||
|     tensors: _TensorOrTensors, | ||||
|     tensors: _TensorOrTensorsOrGradEdge, | ||||
|     grad_tensors: Optional[_TensorOrTensors] = None, | ||||
|     retain_graph: Optional[bool] = None, | ||||
|     create_graph: bool = False, | ||||
| @ -284,8 +284,8 @@ def backward( | ||||
|         See https://github.com/pytorch/pytorch/pull/60521#issuecomment-867061780 for more details. | ||||
|  | ||||
|     Args: | ||||
|         tensors (Sequence[Tensor] or Tensor): Tensors of which the derivative will be | ||||
|             computed. | ||||
|         tensors (Sequence[Tensor] or Tensor or Sequence[GradientEdge] or GradientEdge): Tensors of which | ||||
|             the derivative will be computed. | ||||
|         grad_tensors (Sequence[Tensor or None] or Tensor, optional): The "vector" in | ||||
|             the Jacobian-vector product, usually gradients w.r.t. each element of | ||||
|             corresponding tensors. None values can be specified for scalar Tensors or | ||||
| @ -327,7 +327,12 @@ def backward( | ||||
|     if inputs is not None and len(inputs) == 0: | ||||
|         raise RuntimeError("`inputs` argument to `backward()` cannot be empty.") | ||||
|  | ||||
|     tensors = (tensors,) if isinstance(tensors, torch.Tensor) else tuple(tensors) | ||||
|     if is_tensor_like(tensors) or isinstance(tensors, graph.GradientEdge): | ||||
|         tensors = cast( | ||||
|             Union[Tuple[torch.Tensor], Tuple[graph.GradientEdge]], (tensors,) | ||||
|         ) | ||||
|     else: | ||||
|         tensors = tuple(tensors) | ||||
|     inputs = ( | ||||
|         (inputs,) | ||||
|         if isinstance(inputs, (torch.Tensor, graph.GradientEdge)) | ||||
|  | ||||
		Reference in New Issue
	
	Block a user