Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[functorch] vjp now supports pytree inputs and outputs
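For context, a minimal sketch of what this change enables (assuming the top-level functorch API of this era, where vjp is importable directly from functorch):

    import torch
    from functorch import vjp

    def f(x):
        # x is a pytree input: a (tensor, (tensor, tensor)) tuple
        return x[0] * x[1][0]

    x = torch.randn([])
    out, vjp_fn = vjp(f, (x, (x, x)))  # pytree primals are now accepted
    # cotangents must mirror the pytree structure of f's outputs
    vjps = vjp_fn(torch.randn([]))     # vjps mirror the pytree structure of the primals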
@@ -4,7 +4,7 @@ import collections
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils._pytree import tree_flatten, tree_unflatten
-from .pytree_hacks import tree_map, tree_map_
+from .pytree_hacks import tree_map, tree_map_, treespec_pprint
 import gc
 
 from .vmap import vmap
@@ -57,7 +57,6 @@ def _as_tuple(val):
 
 # Version of autograd.grad that handles outputs that don't depend on inputs
 def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
-    outputs = _as_tuple(outputs)
     if grad_outputs is None:
         diff_outputs = tuple(out for out in outputs if out.requires_grad)
     else:
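Note on the hunk above: _autograd_grad no longer tuple-ifies its outputs argument itself; after this commit its callers pass an already-flat tuple, as the vjp and grad_and_value hunks below do.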
@@ -85,15 +84,20 @@ def vjp(f, *primals):
         primals_out = f(*diff_primals)
         results = _undo_create_differentiable(primals_out, level)
 
+        flat_diff_primals, primals_spec = tree_flatten(diff_primals)
+        flat_primals_out, primals_out_spec = tree_flatten(_as_tuple(primals_out))
+
         def wrapper(*cotangents, retain_graph=True, create_graph=True):
-            primals_out_tuple = _as_tuple(primals_out)
-            if len(primals_out_tuple) != len(cotangents):
+            flat_cotangents, cotangents_spec = tree_flatten(cotangents)
+            if primals_out_spec != cotangents_spec:
                 raise RuntimeError(
-                    f'Got {len(primals_out_tuple)} outputs but {len(cotangents)} '
-                    f'cotangents. These two quantities should be the same')
-            result = _autograd_grad(primals_out_tuple, diff_primals, cotangents,
+                    f'Expected pytree structure of cotangents to be the same '
+                    f'as pytree structure of outputs to the function. '
+                    f'cotangents: {treespec_pprint(cotangents_spec)}, '
+                    f'primal output: {treespec_pprint(primals_out_spec)}')
+            result = _autograd_grad(flat_primals_out, flat_diff_primals, flat_cotangents,
                                     retain_graph=retain_graph, create_graph=create_graph)
-            return result
+            return tree_unflatten(result, primals_spec)
 
     finally:
         _grad_decrement_nesting()
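The pattern the new wrapper uses, restated outside the diff (a sketch built on torch.utils._pytree; the helper name check_cotangents is illustrative, not from the commit):

    from torch.utils._pytree import tree_flatten

    def check_cotangents(cotangents, primals_out_spec):
        # Flattening reduces pytree-vs-pytree bookkeeping to flat-list
        # bookkeeping, and comparing the two TreeSpecs catches any structural
        # mismatch before autograd ever runs.
        flat_cotangents, cotangents_spec = tree_flatten(cotangents)
        if cotangents_spec != primals_out_spec:
            raise RuntimeError('cotangent structure must match output structure')
        return flat_cotangents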
@@ -151,7 +155,8 @@ def grad_and_value(f, argnums=0, has_aux=False):
             flat_diff_args, spec = tree_flatten(diff_args)
 
             # NB: need create_graph so that backward pass isn't run in no_grad mode
-            flat_grad_input = _autograd_grad(output, flat_diff_args, create_graph=True)
+            flat_outputs = _as_tuple(output)
+            flat_grad_input = _autograd_grad(flat_outputs, flat_diff_args, create_graph=True)
             grad_input = tree_unflatten(flat_grad_input, spec)
 
         finally:
@@ -37,3 +37,13 @@ def tree_map_(fn_, pytree):
     flat_args, _ = tree_flatten(pytree)
     [fn_(arg) for arg in flat_args]
     return pytree
+
+
+class PlaceHolder():
+    def __repr__(self):
+        return '*'
+
+def treespec_pprint(spec):
+    leafs = [PlaceHolder() for _ in range(spec.num_leaves)]
+    result = tree_unflatten(leafs, spec)
+    return repr(result)
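A quick illustration of what treespec_pprint renders (the helper is inlined here so the snippet runs standalone):

    from torch.utils._pytree import tree_flatten, tree_unflatten

    class PlaceHolder():
        def __repr__(self):
            return '*'

    def treespec_pprint(spec):
        leafs = [PlaceHolder() for _ in range(spec.num_leaves)]
        return repr(tree_unflatten(leafs, spec))

    _, spec = tree_flatten((1, (2, 3)))
    print(treespec_pprint(spec))  # prints: (*, (*, *))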
@@ -339,6 +339,41 @@ class TestGradTransform(TestCase):
         expected = torch.zeros(N, M, M, device=device)
         self.assertEqual(result, expected)
 
+    def test_vjp_pytree_input(self, device):
+        def f(x):
+            return x[0] * x[1][0]
+
+        x = torch.randn([], device=device)
+        v = torch.randn([], device=device)
+        out, vjp_fn = vjp(f, (x, (x, x)))
+        self.assertEqual(out, x * x)
+        result = vjp_fn(v)
+        self.assertEqual(result, ((x * v, (x * v, 0.)),))
+
+    def test_vjp_pytree_output(self, device):
+        def f(x):
+            return x, (x, x)
+
+        x = torch.randn([], device=device)
+        v1 = torch.randn([], device=device)
+        v2 = torch.randn([], device=device)
+        v3 = torch.randn([], device=device)
+        _, vjp_fn = vjp(f, x)
+        result, = vjp_fn(v1, (v2, v3))
+        self.assertEqual(result, v1 + v2 + v3)
+
+    def test_vjp_pytree_error(self, device):
+        def f(x):
+            return x, (x, x)
+
+        x = torch.randn([], device=device)
+        v1 = torch.randn([], device=device)
+        v2 = torch.randn([], device=device)
+        v3 = torch.randn([], device=device)
+        _, vjp_fn = vjp(f, x)
+        with self.assertRaisesRegex(RuntimeError, 'Expected pytree structure'):
+            result, = vjp_fn((v1, (v2, v3)))
+
 
 class TestVmapOfGrad(TestCase):
     def test_per_sample_grads_inplace_view(self, device):
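Why test_vjp_pytree_error trips the new spec check (a sketch using torch.utils._pytree; the spec strings in the comments follow the treespec_pprint convention above):

    from torch.utils._pytree import tree_flatten

    v1 = v2 = v3 = 0.0
    # f returns (x, (x, x)), so the expected cotangent structure,
    # gathered as *args, is (*, (*, *)).
    _, expected_spec = tree_flatten((v1, (v2, v3)))
    # vjp_fn((v1, (v2, v3))) passes ONE positional argument, so the
    # cotangents tuple becomes ((v1, (v2, v3)),), i.e. ((*, (*, *)),).
    _, passed_spec = tree_flatten(((v1, (v2, v3)),))
    assert expected_spec != passed_spec  # hence the RuntimeError
    # the correct call is vjp_fn(v1, (v2, v3))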