[functorch] Add pytree support for grad and grad_and_value
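
For context, a minimal usage sketch of the API this commit settles on (grad_and_value plus an int-or-tuple argnums). It mirrors the tests added below and assumes a functorch build that includes this change.

# Sketch only: mirrors test_argnums/test_invalid_argnums from this commit.
import torch
from functorch import grad, grad_and_value

x = torch.randn([])
y = torch.randn([])

# argnums picks which positional input(s) to differentiate with respect to.
gx = grad(torch.mul, argnums=0)(x, y)           # d(x*y)/dx == y
gx, gy = grad(torch.mul, argnums=(0, 1))(x, y)  # tuple of argnums -> tuple of grads

# grad_and_value also returns the (scalar) value of f alongside the gradient.
gx, value = grad_and_value(torch.mul, argnums=0)(x, y)
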
@@ -2,7 +2,7 @@ import torch
 from . import _C

 from ._src.vmap import vmap
-from ._src.eager_transforms import grad, grad_with_value, vjp, jacrev
+from ._src.eager_transforms import grad, grad_and_value, vjp, jacrev
 from ._src.make_functional import make_functional, make_functional_with_buffers

 # Monkeypatching lol
@@ -3,6 +3,7 @@ from functools import partial, wraps
 import collections
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils._pytree import tree_flatten, tree_unflatten
 import gc

 from .vmap import vmap
@@ -15,6 +16,16 @@ from functorch._C import (
     _grad_decrement_nesting,
 )

+# TODO: replace this with tree_map from core
+def tree_map(fn, pytree):
+    flat_args, spec = tree_flatten(pytree)
+    return tree_unflatten([fn(arg) for arg in flat_args], spec)
+
+def tree_map_(fn_, pytree):
+    flat_args, _ = tree_flatten(pytree)
+    [fn_(arg) for arg in flat_args]
+    return pytree
+
 # TODO: replace all of these with pytrees
 def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
     if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
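
The temporary tree_map/tree_map_ helpers above only flatten a pytree, apply a function to each leaf, and (for tree_map) rebuild the original container structure. A small stand-alone sketch of that round trip using torch.utils._pytree; the nested example value here is illustrative.

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

nested = {'w': torch.ones(2), 'stats': (torch.zeros(3), torch.ones(1))}

# tree_flatten returns the leaf tensors plus a spec describing the container layout.
leaves, spec = tree_flatten(nested)

# Apply a function per leaf, then restore the original dict/tuple nesting.
doubled = tree_unflatten([leaf * 2 for leaf in leaves], spec)
assert doubled['w'].tolist() == [2.0, 2.0]
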
@@ -38,25 +49,22 @@ def _undo_create_differentiable(tensor_or_tuple_of_tensors, level=None):
         return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
     assert False

-def _any_differentiable(tensor_or_tuple_of_tensors):
-    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
-        tensor = tensor_or_tuple_of_tensors
-        return tensor.requires_grad
-    if isinstance(tensor_or_tuple_of_tensors, tuple):
-        return any(tuple(map(_any_differentiable, tensor_or_tuple_of_tensors)))
-    if isinstance(tensor_or_tuple_of_tensors, list):
-        return any(tuple(map(_any_differentiable, tensor_or_tuple_of_tensors)))
+def _is_differentiable(maybe_tensor):
+    if not isinstance(maybe_tensor, torch.Tensor):
+        return False
+    return maybe_tensor.requires_grad
+
+def _any_differentiable(tensor_or_tuple_of_tensors):
+    flat_args, _ = tree_flatten(tensor_or_tuple_of_tensors)
+    return any(tuple(map(_is_differentiable, flat_args)))
+
+def _wrap_tensor_for_grad(maybe_tensor, level):
+    if not isinstance(maybe_tensor, torch.Tensor):
+        return maybe_tensor
+    return _wrap_for_grad(maybe_tensor, level)

 def _wrap_all_tensors(tensor_or_tuple_of_tensors, level):
-    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
-        tensor = tensor_or_tuple_of_tensors
-        return _wrap_for_grad(tensor, level)
-    if isinstance(tensor_or_tuple_of_tensors, tuple):
-        return tuple(map(partial(_wrap_all_tensors, level=level), tensor_or_tuple_of_tensors))
-    if isinstance(tensor_or_tuple_of_tensors, list):
-        return tuple(map(partial(_wrap_all_tensors, level=level), tensor_or_tuple_of_tensors))
-    return tensor_or_tuple_of_tensors
+    return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_or_tuple_of_tensors)

 # How do we increment and decrement the nesting? I don't think we can.
 def vjp(f, *primals):
@@ -89,36 +97,50 @@ def jacrev(f):
         return result
     return wrapper_fn

-def grad_with_value(f, diff_argnums=(0,), has_aux=False):
+def _safe_index(args, argnum):
+    if not isinstance(argnum, int):
+        raise RuntimeError(f'argnum must be int, got: {type(argnum)}')
+    if argnum >= 0 and argnum < len(args):
+        return args[argnum]
+    raise RuntimeError(f'Got argnum={argnum}, but only {len(args)} inputs')
+
+def _slice_argnums(args, argnums):
+    if isinstance(argnums, int):
+        return _safe_index(args, argnums)
+    if isinstance(argnums, tuple):
+        return tuple(_safe_index(args, i) for i in argnums)
+    raise RuntimeError(f'argnums must be int or Tuple[int, ...], got: {type(argnums)}')
+
+def grad_and_value(f, argnums=0, has_aux=False):
     def wrapper(*args):
         level = _grad_increment_nesting()
         output, aux, grad_input = None, None, None
         try:
             args = _wrap_all_tensors(args, level)
-            args = [_create_differentiable(arg, level) if i in diff_argnums else arg
-                    for i, arg in enumerate(args)]
-            # print("calling f(*args)")
+            diff_args = _slice_argnums(args, argnums)
+            tree_map_(partial(_create_differentiable, level=level), diff_args)
+
             output = f(*args)
-            # print("done with f(*args)")
             if has_aux:
                 output, aux = output
-            # print("calling output.dim()")
-            assert output.dim() == 0
-            diff_args = [args[i] for i in diff_argnums]
-            single_diff_arg = isinstance(diff_args[0], torch.Tensor) and len(diff_args) == 1
-            # TODO: quick hack...
-            if len(diff_args) == 1 and isinstance(diff_args[0], tuple):
-                diff_args = diff_args[0]
+
+            if not isinstance(output, torch.Tensor):
+                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args)'
+                                   f'to return a Tensor, got {type(output)}')
+            if output.dim() != 0:
+                raise RuntimeError('grad_and_value(f)(*args): Expected f(*args)'
+                                   'to return a scalar Tensor, got tensor with '
+                                   f'{output.dim()} dims. Maybe you wanted to'
+                                   'use the vjp or jacrev APIs instead?')
+
+            flat_diff_args, spec = tree_flatten(diff_args)
+
             # NB: need create_graph so that backward pass isn't run in no_grad mode
-            # import torchviz; import graphviz
-            # graph = torchviz.make_dot(output)
-            # graph.save("inner.dot")
-            # print("calling autograd.grad")
-            grad_input = torch.autograd.grad(
-                output, diff_args, create_graph=True)
-            # print("done-ish!")
-            if single_diff_arg:
-                grad_input = grad_input[0]
+            flat_grad_input = torch.autograd.grad(
+                output, flat_diff_args, create_graph=True)
+
+            grad_input = tree_unflatten(flat_grad_input, spec)
+
         finally:
             if grad_input is not None:
                 grad_input = _undo_create_differentiable(grad_input, level)
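
With the rewrite above, grad_and_value flattens whatever argnums selects, calls autograd.grad on the flat leaves, and unflattens the result, so gradients come back with the same structure as the selected inputs. A hedged sketch in the spirit of the TestExamplesCorrectness change further down; compute_loss, weights, and x are made-up names, and a functorch build with this change is assumed.

import torch
from functorch import grad_and_value

def compute_loss(weights, x):
    w1, w2 = weights
    return ((x @ w1) @ w2).sum()

weights = (torch.randn(3, 4), torch.randn(4, 1))
x = torch.randn(2, 3)

# argnums defaults to 0, so the gradient has the same tuple structure as weights.
(gw1, gw2), loss = grad_and_value(compute_loss)(weights, x)
assert gw1.shape == weights[0].shape and gw2.shape == weights[1].shape
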
@@ -132,10 +154,10 @@ def grad_with_value(f, diff_argnums=(0,), has_aux=False):
         return grad_input, output
     return wrapper

-def grad(f, diff_argnums=(0,), has_aux=False):
+def grad(f, argnums=0, has_aux=False):
     @wraps(f)
     def wrapper(*args):
-        results = grad_with_value(f, diff_argnums, has_aux=has_aux)(*args)
+        results = grad_and_value(f, argnums, has_aux=has_aux)(*args)
         if has_aux:
             return results[0], results[2]
         return results[0]
@@ -15,7 +15,7 @@ from functools import partial

 import functorch
 from functorch import (
-    grad, vjp, vmap, jacrev, grad_with_value,
+    grad, vjp, vmap, jacrev, grad_and_value,
     make_functional, make_functional_with_buffers,
 )

@@ -251,6 +251,33 @@ class TestGradTransform(TestCase):
         result = grad(foo)(x)
         self.assertEqual(result, x.cos())

+    def test_invalid_argnums(self, device):
+        x = torch.randn([])
+        y = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, 'but only'):
+            grad(torch.mul, argnums=-1)(x, y)
+        with self.assertRaisesRegex(RuntimeError, 'but only'):
+            grad(torch.mul, argnums=2)(x, y)
+        with self.assertRaisesRegex(RuntimeError, 'int or Tuple'):
+            grad(torch.mul, argnums=[0])(x, y)
+        with self.assertRaisesRegex(RuntimeError, 'must be int'):
+            grad(torch.mul, argnums=('0',))(x, y)
+
+    def test_argnums(self, device):
+        x = torch.randn([])
+        y = torch.randn([])
+        gx = grad(torch.mul, argnums=0)(x, y)
+        self.assertEqual(gx, y)
+
+        gy = grad(torch.mul, argnums=1)(x, y)
+        self.assertEqual(gy, x)
+
+        gx, = grad(torch.mul, argnums=(0,))(x, y)
+        self.assertEqual(gx, y)
+
+        gx, gy = grad(torch.mul, argnums=(0, 1))(x, y)
+        self.assertEqual(gx, y)
+        self.assertEqual(gy, x)

 class TestVmapOfGrad(TestCase):
     def test_per_sample_grads_inplace_view(self, device):
@@ -672,7 +699,7 @@ class TestExamplesCorrectness(TestCase):
             return loss

         if use_transform:
-            grad_weights, loss = grad_with_value(compute_loss)(weights, batch, targets)
+            grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets)
         else:
             loss = compute_loss(weights, batch, targets)
             grad_weights = torch.autograd.grad(loss, weights)