[functorch] Add pytree support for grad and grad_and_value

Richard Zou
2021-04-30 07:51:02 -07:00
committed by Jon Janzen
parent 8c8685ca39
commit c6773c67d6
3 changed files with 91 additions and 42 deletions

View File

@@ -2,7 +2,7 @@ import torch
from . import _C
from ._src.vmap import vmap
from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
from ._src.make_functional import make_functional, make_functional_with_buffers
# Monkeypatching lol

View File

@@ -3,6 +3,7 @@ from functools import partial, wraps
import collections
import torch.nn as nn
import torch.nn.functional as F
from torch.utils._pytree import tree_flatten, tree_unflatten
import gc
from .vmap import vmap
@@ -15,6 +16,16 @@ from functorch._C import (
_grad_decrement_nesting,
)
# TODO: replace this with tree_map from core
def tree_map(fn, pytree):
flat_args, spec = tree_flatten(pytree)
return tree_unflatten([fn(arg) for arg in flat_args], spec)
def tree_map_(fn_, pytree):
flat_args, _ = tree_flatten(pytree)
[fn_(arg) for arg in flat_args]
return pytree
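
For orientation, a minimal hypothetical sketch (not part of this commit) of what these stop-gap helpers do: tree_map rebuilds the pytree with the function applied to every leaf, while tree_map_ applies a side-effecting function to every leaf and returns the original pytree unchanged.

import torch
from torch.utils._pytree import tree_flatten

nested = (torch.ones(2), {'b': torch.zeros(3)}, 'not a tensor')
flat, spec = tree_flatten(nested)  # leaves [tensor, tensor, 'not a tensor']; spec records the nesting
doubled = tree_map(lambda x: x * 2 if isinstance(x, torch.Tensor) else x, nested)  # same nesting, tensor leaves doubled
tree_map_(print, nested)  # prints every leaf; returns `nested` itself
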
# TODO: replace all of these with pytrees
def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
@@ -38,25 +49,22 @@ def _undo_create_differentiable(tensor_or_tuple_of_tensors, level=None):
return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
assert False
def _is_differentiable(maybe_tensor):
if not isinstance(maybe_tensor, torch.Tensor):
return False
return maybe_tensor.requires_grad
def _any_differentiable(tensor_or_tuple_of_tensors):
if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
tensor = tensor_or_tuple_of_tensors
return tensor.requires_grad
if isinstance(tensor_or_tuple_of_tensors, tuple):
return any(tuple(map(_any_differentiable, tensor_or_tuple_of_tensors)))
if isinstance(tensor_or_tuple_of_tensors, list):
return any(tuple(map(_any_differentiable, tensor_or_tuple_of_tensors)))
return False
flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
return any(tuple(map(_is_differentiable, flat_args)))
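
A quick hypothetical illustration (not part of this commit) of the pytree-based differentiability check:

import torch
x = torch.randn(3, requires_grad=True)
y = torch.randn(3)
_any_differentiable((y, [y, {'a': x}]))  # True: at least one leaf requires grad
_any_differentiable((y, [y]))            # False: no leaf requires grad
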
def _wrap_tensor_for_grad(maybe_tensor, level):
if not isinstance(maybe_tensor, torch.Tensor):
return maybe_tensor
return _wrap_for_grad(maybe_tensor, level)
def _wrap_all_tensors(tensor_or_tuple_of_tensors, level):
if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
tensor = tensor_or_tuple_of_tensors
return _wrap_for_grad(tensor, level)
if isinstance(tensor_or_tuple_of_tensors, tuple):
return tuple(map(partial(_wrap_all_tensors, level=level), tensor_or_tuple_of_tensors))
if isinstance(tensor_or_tuple_of_tensors, list):
return tuple(map(partial(_wrap_all_tensors, level=level), tensor_or_tuple_of_tensors))
return tensor_or_tuple_of_tensors
return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_or_tuple_of_tensors)
# How do we increment and decrement the nesting? I don't think we can.
def vjp(f, *primals):
@@ -89,36 +97,50 @@ def jacrev(f):
return result
return wrapper_fn
def grad_with_value(f, diff_argnums=(0,), has_aux=False):
def _safe_index(args, argnum):
if not isinstance(argnum, int):
raise RuntimeError(f'argnum must be int, got: {type(argnum)}')
if argnum >= 0 and argnum < len(args):
return args[argnum]
raise RuntimeError(f'Got argnum={argnum}, but only {len(args)} inputs')
def _slice_argnums(args, argnums):
if isinstance(argnums, int):
return _safe_index(args, argnums)
if isinstance(argnums, tuple):
return tuple(_safe_index(args, i) for i in argnums)
raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
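
Illustrative behavior of the new argnums handling (hypothetical inputs, in the spirit of the tests added below):

args = ('a', 'b', 'c')
_slice_argnums(args, 0)        # 'a'
_slice_argnums(args, (0, 2))   # ('a', 'c')
_slice_argnums(args, 3)        # RuntimeError: Got argnum=3, but only 3 inputs
_slice_argnums(args, [0])      # RuntimeError: argnums must be int or Tuple[int, ...]
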
def grad_and_value(f, argnums=0, has_aux=False):
def wrapper(*args):
level = _grad_increment_nesting()
output, aux, grad_input = None, None, None
try:
args = _wrap_all_tensors(args, level)
args = [_create_differentiable(arg, level) if i in diff_argnums else arg
for i, arg in enumerate(args)]
# print("calling f(*args)")
diff_args = _slice_argnums(args, argnums)
tree_map_(partial(_create_differentiable, level=level), diff_args)
output = f(*args)
# print("done with f(*args)")
if has_aux:
output, aux = output
# print("calling output.dim()")
assert output.dim() == 0
diff_args = [args[i] for i in diff_argnums]
single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1
# TODO: quick hack...
if len(diff_args) == 1 and isinstance(diff_args[0], tuple):
diff_args = diff_args[0]
if not isinstance(output, torch.Tensor):
raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
f'to return a Tensor, got {type(output)}')
if output.dim() != 0:
raise RuntimeError('grad_and_value(f)(*args): Expected f(*args) '
'to return a scalar Tensor, got tensor with '
f'{output.dim()} dims. Maybe you wanted to '
'use the vjp or jacrev APIs instead?')
flat_diff_args, spec = tree_flatten(diff_args)
# NB: need create_graph so that backward pass isn't run in no_grad mode
# import torchviz; import graphviz
# graph = torchviz.make_dot(output)
# graph.save("inner.dot")
# print("calling autograd.grad")
grad_input = torch.autograd.grad(
output, diff_args, create_graph=True)
# print("done-ish!")
if single_diff_arg:
grad_input = grad_input[0]
flat_grad_input = torch.autograd.grad(
output, flat_diff_args, create_graph=True)
grad_input = tree_unflatten(flat_grad_input, spec)
finally:
if grad_input is not None:
grad_input = _undo_create_differentiable(grad_input, level)
@@ -132,10 +154,10 @@ def grad_with_value(f, diff_argnums=(0,), has_aux=False):
return grad_input, output
return wrapper
def grad(f, diff_argnums=(0,), has_aux=False):
def grad(f, argnums=0, has_aux=False):
@wraps(f)
def wrapper(*args):
results = grad_with_value(f, diff_argnums, has_aux=has_aux)(*args)
results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
if has_aux:
return results[0], results[2]
return results[0]
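
Taken together, a hypothetical usage sketch (not part of this commit) of the reworked API: because the selected arguments are flattened with tree_flatten and the gradients are rebuilt with the same spec, an argument can be a nested structure of tensors (nested tuples here) and the returned gradients mirror that nesting.

import torch
from functorch import grad, grad_and_value

# Two "layers" of parameters, kept as nested tuples.
params = ((torch.randn(3, 4), torch.randn(4)), (torch.randn(4, 2), torch.randn(2)))
x = torch.randn(3)

def loss_fn(params, x):
    (w1, b1), (w2, b2) = params
    h = torch.tanh(x @ w1 + b1)
    return ((h @ w2 + b2) ** 2).sum()

grads, loss = grad_and_value(loss_fn)(params, x)  # grads mirrors the ((w1, b1), (w2, b2)) nesting
gx = grad(loss_fn, argnums=1)(params, x)          # gradient with respect to x instead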

View File

@@ -15,7 +15,7 @@ from functools import partial
import functorch
from functorch import (
grad, vjp, vmap, jacrev, grad_with_value,
grad, vjp, vmap, jacrev, grad_and_value,
make_functional, make_functional_with_buffers,
)
@@ -251,6 +251,33 @@ class TestGradTransform(TestCase):
result = grad(foo)(x)
self.assertEqual(result, x.cos())
def test_invalid_argnums(self, device):
x = torch.randn([])
y = torch.randn([])
with self.assertRaisesRegex(RuntimeError, 'but only'):
grad(torch.mul, argnums=-1)(x, y)
with self.assertRaisesRegex(RuntimeError, 'but only'):
grad(torch.mul, argnums=2)(x, y)
with self.assertRaisesRegex(RuntimeError, 'int or Tuple'):
grad(torch.mul, argnums=[0])(x, y)
with self.assertRaisesRegex(RuntimeError, 'must be int'):
grad(torch.mul, argnums=('0',))(x, y)
def test_argnums(self, device):
x = torch.randn([])
y = torch.randn([])
gx = grad(torch.mul, argnums=0)(x, y)
self.assertEqual(gx, y)
gy = grad(torch.mul, argnums=1)(x, y)
self.assertEqual(gy, x)
gx, = grad(torch.mul, argnums=(0,))(x, y)
self.assertEqual(gx, y)
gx, gy = grad(torch.mul, argnums=(0, 1))(x, y)
self.assertEqual(gx, y)
self.assertEqual(gy, x)
class TestVmapOfGrad(TestCase):
def test_per_sample_grads_inplace_view(self, device):
@@ -672,7 +699,7 @@ class TestExamplesCorrectness(TestCase):
return loss
if use_transform:
grad_weights, loss = grad_with_value(compute_loss)(weights, batch, targets)
grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
else:
loss = compute_loss(weights, batch, targets)
grad_weights = torch.autograd.grad(loss, weights)