[functorch] vjp now supports pytree inputs and outputs

Richard Zou
2021-05-05 08:06:57 -07:00
committed by Jon Janzen
parent 9490aa0b65
commit f92fbeef74
3 changed files with 59 additions and 9 deletions
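
In short: vjp(f, *primals) now accepts arbitrary pytrees (nested tuples/lists/dicts of tensors) as primals and as outputs of f, and the returned vjp_fn checks that the cotangents match the output's pytree structure. A minimal usage sketch, mirroring the tests added below (assumes vjp is imported from functorch):

    import torch
    from functorch import vjp

    def f(x):
        # x is a pytree: a tuple containing a tensor and a nested tuple of tensors
        return x[0] * x[1][0]

    x = torch.randn([])
    v = torch.randn([])

    # The primals may now be an arbitrary pytree...
    out, vjp_fn = vjp(f, (x, (x, x)))
    # ...and the returned vjps mirror the structure of the primals.
    vjps = vjp_fn(v)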

functorch/_src/eager_transforms.py

@@ -4,7 +4,7 @@ import collections
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from .pytree_hacks import tree_map, tree_map_
+from .pytree_hacks import tree_map, tree_map_, treespec_pprint
 import gc
 
 from .vmap import vmap
@@ -57,7 +57,6 @@ def _as_tuple(val):
 
 # Version of autograd.grad that handles outputs that don't depend on inputs
 def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
-    outputs = _as_tuple(outputs)
     if grad_outputs is None:
         diff_outputs = tuple(out for out in outputs if out.requires_grad)
     else:
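
The removed line means _autograd_grad no longer coerces outputs to a tuple; callers now pass already-flattened tuples, matching the contract of the underlying torch.autograd.grad. For reference, a small illustration of that stock API:

    import torch

    x = torch.randn([], requires_grad=True)
    outputs = (x * x, 3 * x)                       # flat tuple of outputs
    cotangents = (torch.ones([]), torch.ones([]))  # one cotangent per output

    # torch.autograd.grad expects flat sequences of outputs and grad_outputs
    grads = torch.autograd.grad(outputs, (x,), cotangents)
    # grads == (2 * x + 3,)
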
@@ -85,15 +84,20 @@ def vjp(f, *primals):
         primals_out = f(*diff_primals)
         results = _undo_create_differentiable(primals_out, level)
 
+        flat_diff_primals, primals_spec = tree_flatten(diff_primals)
+        flat_primals_out, primals_out_spec = tree_flatten(_as_tuple(primals_out))
+
         def wrapper(*cotangents, retain_graph=True, create_graph=True):
-            primals_out_tuple = _as_tuple(primals_out)
-            if len(primals_out_tuple) != len(cotangents):
+            flat_cotangents, cotangents_spec = tree_flatten(cotangents)
+            if primals_out_spec != cotangents_spec:
                 raise RuntimeError(
-                    f'Got {len(primals_out_tuple)} outputs but {len(cotangents)} '
-                    f'cotangents. These two quantities should be the same')
-            result = _autograd_grad(primals_out_tuple, diff_primals, cotangents,
+                    f'Expected pytree structure of cotangents to be the same '
+                    f'as pytree structure of outputs to the function. '
+                    f'cotangents: {treespec_pprint(cotangents_spec)}, '
+                    f'primal output: {treespec_pprint(primals_out_spec)}')
+            result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
                                     retain_graph=retain_graph, create_graph=create_graph)
-            return result
+            return tree_unflatten(result, primals_spec)
 
     finally:
         _grad_decrement_nesting()
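
The rewritten wrapper is an instance of the usual flatten / operate / unflatten pattern from torch.utils._pytree: flatten both sides, compare the treespecs, run autograd on the flat lists, then rebuild the primal structure. Roughly (a standalone sketch, not the functorch internals):

    from torch.utils._pytree import tree_flatten, tree_unflatten

    primals = (1, (2, 3))
    flat, spec = tree_flatten(primals)       # flat == [1, 2, 3]
    doubled = [2 * leaf for leaf in flat]    # operate on the flat leaves
    result = tree_unflatten(doubled, spec)   # result == (2, (4, 6))
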
@@ -151,7 +155,8 @@ def grad_and_value(f, argnums=0, has_aux=False):
             flat_diff_args, spec = tree_flatten(diff_args)
             # NB: need create_graph so that backward pass isn't run in no_grad mode
-            flat_grad_input = _autograd_grad(output, flat_diff_args, create_graph=True)
+            flat_outputs = _as_tuple(output)
+            flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
             grad_input = tree_unflatten(flat_grad_input, spec)
 
     finally:
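
Since _autograd_grad no longer tuple-wraps its first argument (see the hunk above), grad_and_value now does so at the call site via _as_tuple, which is defined earlier in this file and presumably behaves like:

    def _as_tuple(val):
        # Leave tuples alone; wrap everything else in a 1-tuple.
        if isinstance(val, tuple):
            return val
        return (val,)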

functorch/_src/pytree_hacks.py

@@ -37,3 +37,13 @@ def tree_map_(fn_, pytree):
     flat_args, _ = tree_flatten(pytree)
     [fn_(arg) for arg in flat_args]
     return pytree
+
+
+class PlaceHolder():
+    def __repr__(self):
+        return '*'
+
+
+def treespec_pprint(spec):
+    leafs = [PlaceHolder() for _ in range(spec.num_leaves)]
+    result = tree_unflatten(leafs, spec)
+    return repr(result)
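
treespec_pprint renders a treespec with * standing in for each leaf, which is what makes the structure-mismatch error above readable. A usage sketch, assuming the module sits at functorch/_src/pytree_hacks.py as labeled above:

    from torch.utils._pytree import tree_flatten
    from functorch._src.pytree_hacks import treespec_pprint

    _, spec = tree_flatten((1, (2, 3)))
    print(treespec_pprint(spec))  # prints: (*, (*, *))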

test/test_eager_transforms.py

@@ -339,6 +339,41 @@ class TestGradTransform(TestCase):
         expected = torch.zeros(N, M, M, device=device)
         self.assertEqual(result, expected)
 
+    def test_vjp_pytree_input(self, device):
+        def f(x):
+            return x[0] * x[1][0]
+
+        x = torch.randn([], device=device)
+        v = torch.randn([], device=device)
+        out, vjp_fn = vjp(f, (x, (x, x)))
+        self.assertEqual(out, x * x)
+        result = vjp_fn(v)
+        self.assertEqual(result, ((x * v, (x * v, 0.)),))
+
+    def test_vjp_pytree_output(self, device):
+        def f(x):
+            return x, (x, x)
+
+        x = torch.randn([], device=device)
+        v1 = torch.randn([], device=device)
+        v2 = torch.randn([], device=device)
+        v3 = torch.randn([], device=device)
+        _, vjp_fn = vjp(f, x)
+        result, = vjp_fn(v1, (v2, v3))
+        self.assertEqual(result, v1 + v2 + v3)
+
+    def test_vjp_pytree_error(self, device):
+        def f(x):
+            return x, (x, x)
+
+        x = torch.randn([], device=device)
+        v1 = torch.randn([], device=device)
+        v2 = torch.randn([], device=device)
+        v3 = torch.randn([], device=device)
+        _, vjp_fn = vjp(f, x)
+        with self.assertRaisesRegex(RuntimeError, 'Expected pytree structure'):
+            result, = vjp_fn((v1, (v2, v3)))
+
 
 class TestVmapOfGrad(TestCase):
     def test_per_sample_grads_inplace_view(self, device):