diff --git a/functorch/functorch/_src/eager_transforms.py b/functorch/functorch/_src/eager_transforms.py
index a51a61ec5dbf..17fb01d23783 100644
--- a/functorch/functorch/_src/eager_transforms.py
+++ b/functorch/functorch/_src/eager_transforms.py
@@ -4,7 +4,7 @@ import collections
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from .pytree_hacks import tree_map, tree_map_
+from .pytree_hacks import tree_map, tree_map_, treespec_pprint
 import gc
 
 from .vmap import vmap
@@ -57,7 +57,6 @@ def _as_tuple(val):
 
 # Version of autograd.grad that handles outputs that don't depend on inputs
 def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
-    outputs = _as_tuple(outputs)
     if grad_outputs is None:
         diff_outputs = tuple(out for out in outputs if out.requires_grad)
     else:
@@ -85,15 +84,20 @@ def vjp(f, *primals):
         primals_out = f(*diff_primals)
         results = _undo_create_differentiable(primals_out, level)
 
+        flat_diff_primals, primals_spec = tree_flatten(diff_primals)
+        flat_primals_out, primals_out_spec = tree_flatten(_as_tuple(primals_out))
+
         def wrapper(*cotangents, retain_graph=True, create_graph=True):
-            primals_out_tuple = _as_tuple(primals_out)
-            if len(primals_out_tuple) != len(cotangents):
+            flat_cotangents, cotangents_spec = tree_flatten(cotangents)
+            if primals_out_spec != cotangents_spec:
                 raise RuntimeError(
-                    f'Got {len(primals_out_tuple)} outputs but {len(cotangents)} '
-                    f'cotangents. These two quantities should be the same')
-            result = _autograd_grad(primals_out_tuple, diff_primals, cotangents,
+                    f'Expected pytree structure of cotangents to be the same '
+                    f'as pytree structure of outputs to the function. '
+                    f'cotangents: {treespec_pprint(cotangents_spec)}, '
+                    f'primal output: {treespec_pprint(primals_out_spec)}')
+            result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
                                     retain_graph=retain_graph, create_graph=create_graph)
-            return result
+            return tree_unflatten(result, primals_spec)
     finally:
         _grad_decrement_nesting()
 
@@ -151,7 +155,8 @@ def grad_and_value(f, argnums=0, has_aux=False):
             flat_diff_args, spec = tree_flatten(diff_args)
 
             # NB: need create_graph so that backward pass isn't run in no_grad mode
-            flat_grad_input = _autograd_grad(output, flat_diff_args, create_graph=True)
+            flat_outputs = _as_tuple(output)
+            flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
 
             grad_input = tree_unflatten(flat_grad_input, spec)
         finally:
diff --git a/functorch/functorch/_src/pytree_hacks.py b/functorch/functorch/_src/pytree_hacks.py
index 2aef9f61cae6..5cf9bef22374 100644
--- a/functorch/functorch/_src/pytree_hacks.py
+++ b/functorch/functorch/_src/pytree_hacks.py
@@ -37,3 +37,13 @@ def tree_map_(fn_, pytree):
     flat_args, _ = tree_flatten(pytree)
     [fn_(arg) for arg in flat_args]
     return pytree
+
+class PlaceHolder():
+    def __repr__(self):
+        return '*'
+
+def treespec_pprint(spec):
+    leafs = [PlaceHolder() for _ in range(spec.num_leaves)]
+    result = tree_unflatten(leafs, spec)
+    return repr(result)
+
diff --git a/functorch/test/test_eager_transforms.py b/functorch/test/test_eager_transforms.py
index 03964e42d96c..05b98c22c039 100644
--- a/functorch/test/test_eager_transforms.py
+++ b/functorch/test/test_eager_transforms.py
@@ -339,6 +339,41 @@ class TestGradTransform(TestCase):
         expected = torch.zeros(N, M, M, device=device)
         self.assertEqual(result, expected)
 
+    def test_vjp_pytree_input(self, device):
+        def f(x):
+            return x[0] * x[1][0]
+
+        x = torch.randn([], device=device)
+        v = torch.randn([], device=device)
+        out, vjp_fn = vjp(f, (x, (x, x)))
+        self.assertEqual(out, x * x)
+        result = vjp_fn(v)
+        self.assertEqual(result, ((x * v, (x * v, 0.)),))
+
+    def test_vjp_pytree_output(self, device):
+        def f(x):
+            return x, (x, x)
+
+        x = torch.randn([], device=device)
+        v1 = torch.randn([], device=device)
+        v2 = torch.randn([], device=device)
+        v3 = torch.randn([], device=device)
+        _, vjp_fn = vjp(f, x)
+        result, = vjp_fn(v1, (v2, v3))
+        self.assertEqual(result, v1 + v2 + v3)
+
+    def test_vjp_pytree_error(self, device):
+        def f(x):
+            return x, (x, x)
+
+        x = torch.randn([], device=device)
+        v1 = torch.randn([], device=device)
+        v2 = torch.randn([], device=device)
+        v3 = torch.randn([], device=device)
+        _, vjp_fn = vjp(f, x)
+        with self.assertRaisesRegex(RuntimeError, 'Expected pytree structure'):
+            result, = vjp_fn((v1, (v2, v3)))
+
 
 class TestVmapOfGrad(TestCase):
     def test_per_sample_grads_inplace_view(self, device):
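
For context, a minimal usage sketch (not part of the patch) of the pytree-aware vjp behavior the new tests exercise; it assumes the top-level functorch vjp export, and the expected gradient mirrors test_vjp_pytree_output.

```python
import torch
from functorch import vjp  # assumes functorch's top-level vjp export

def f(x):
    # The output is a pytree: (tensor, (tensor, tensor))
    return x, (x, x)

x = torch.randn([])
out, vjp_fn = vjp(f, x)
# Cotangents must mirror the output pytree: a tensor plus a (tensor, tensor) pair.
grads = vjp_fn(torch.ones([]), (torch.ones([]), torch.ones([])))
print(grads)  # (tensor(3.),) -- gradients come back with the structure of the primals
```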