[functorch] adds support for grad when the output doesn't depend on some inputs (pytorch/functorch#6)

* adds support for grad when the output doesn't depend on some inputs

* fix some issues

* responded to comments
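
For context, here is a minimal usage sketch of the behavior this change enables. It mirrors the new test_zero_grad test added below and assumes grad is importable from the top-level functorch package; it is not part of the diff. Gradients for inputs that never influence the output now come back as zero-filled tensors instead of torch.autograd.grad raising an error.

    import torch
    from functorch import grad  # assumed import path, mirroring the new test below

    def f(inps):
        # only inps['a'] contributes to the loss; inps['b'] is never used
        return (inps['a'] ** 2).sum()

    inps = {'a': torch.randn(10), 'b': torch.randn(10)}
    grads = grad(f)(inps)
    # grads['a'] == 2 * inps['a']
    # grads['b'] is a zero tensor of the same shape as inps['b'], not None and not an error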
Author: Horace He
Date: 2021-05-03 10:43:20 -07:00
Committed by: Jon Janzen
Parent: 86e49cf0d7
Commit: 80131b937d
3 changed files with 29 additions and 23 deletions


@@ -17,28 +17,21 @@ from functorch._C import (
     _grad_decrement_nesting,
 )
 
 # TODO: replace all of these with pytrees
-def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
-    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
-        tensor = tensor_or_tuple_of_tensors
-        aliased = tensor
-        return aliased.requires_grad_()
-    if isinstance(tensor_or_tuple_of_tensors, tuple):
-        return tuple(map(partial(_create_differentiable, level=level), tensor_or_tuple_of_tensors))
-    if isinstance(tensor_or_tuple_of_tensors, list):
-        return tuple(map(partial(_create_differentiable, level=level), tensor_or_tuple_of_tensors))
-    raise ValueError(f'Thing passed to transform API must be Tensor, List or Tuple, '
-                     f'got {type(tensor_or_tuple_of_tensors)}')
+def _create_differentiable(inps, level=None):
+    def create_differentiable(x):
+        if isinstance(x, torch.Tensor):
+            return x.requires_grad_()
+        raise ValueError(f'Thing passed to transform API must be Tensor,'
+                         f'got {type(x)}')
+    return tree_map(create_differentiable, inps)
 
-def _undo_create_differentiable(tensor_or_tuple_of_tensors, level=None):
-    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
-        tensor = tensor_or_tuple_of_tensors
-        return _unwrap_for_grad(tensor, level)
-    if isinstance(tensor_or_tuple_of_tensors, tuple):
-        return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
-    if isinstance(tensor_or_tuple_of_tensors, list):
-        return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
-    assert False
+def _undo_create_differentiable(inps, level=None):
+    def unwrap_tensors(x):
+        if isinstance(x, torch.Tensor):
+            return _unwrap_for_grad(x, level)
+        assert False
+    return tree_map(unwrap_tensors, inps)
 
 def _is_differentiable(maybe_tensor):
     if not isinstance(maybe_tensor, torch.Tensor):
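
The refactor above swaps the hand-rolled tuple/list recursion for a pytree traversal, so dicts and other nested containers work too. A small illustration of the idea, assuming tree_map comes from torch.utils._pytree as the surrounding functorch code does (the input structure here is made up):

    import torch
    from torch.utils._pytree import tree_map

    # An arbitrarily nested container whose leaves are all tensors.
    inps = {'a': torch.randn(2), 'b': [torch.randn(3), (torch.randn(4),)]}

    # Apply one function to every leaf; the dict/list/tuple structure is preserved.
    diff_inps = tree_map(lambda t: t.requires_grad_(), inps)
    # Every tensor leaf now has requires_grad=True, which is what
    # _create_differentiable needs without special-casing tuples and lists.
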
@@ -128,8 +121,13 @@ def grad_and_value(f, argnums=0, has_aux=False):
             # NB: need create_graph so that backward pass isn't run in no_grad mode
             flat_grad_input = torch.autograd.grad(
-                output, flat_diff_args, create_graph=True)
+                output, flat_diff_args, create_graph=True, allow_unused=True)
+            def replace_none_with_zeros(grad_arg, orig_arg):
+                if grad_arg is None:
+                    return torch.zeros_like(orig_arg)
+                return grad_arg
+            flat_grad_input = [replace_none_with_zeros(x, flat_diff_args[idx]) for idx, x in enumerate(flat_grad_input)]
             grad_input = tree_unflatten(flat_grad_input, spec)
         finally:
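
The change above passes allow_unused=True so that torch.autograd.grad no longer raises when some differentiable inputs are not reachable from the output, and the new replace_none_with_zeros helper turns the resulting None entries into zero tensors. A standalone sketch of that plain-PyTorch behavior (not functorch-specific):

    import torch

    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)
    out = (a ** 2).sum()  # b never participates in the graph

    # Without allow_unused=True this call raises because b is unused.
    grads = torch.autograd.grad(out, [a, b], create_graph=True, allow_unused=True)
    # grads == (2 * a, None)

    # Zero-fill the unused slot, mirroring replace_none_with_zeros above.
    grads = [g if g is not None else torch.zeros_like(x)
             for g, x in zip(grads, [a, b])]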


@@ -89,7 +89,6 @@ def wrap_key(f, inps):
     @functools.wraps(f)
     def wrapped(*args):
         flat_args, args_spec = pytree.tree_flatten(args)
-        import pdb; pdb.set_trace()
         assert(len(flat_args) == len(flat_inps))
         for idx, arg in enumerate(flat_args):
             if isinstance(flat_inps[idx], torch.Tensor):


@@ -279,6 +279,15 @@ class TestGradTransform(TestCase):
         self.assertEqual(gx, y)
         self.assertEqual(gy, x)
+
+    def test_zero_grad(self, device):
+        def f(x):
+            return (x['a']**2.0).sum()
+        inps = ({'a':torch.randn(10) + 3, 'b':torch.randn(10)})
+        grads = grad(f)(inps)
+        self.assertNotEqual(grads['a'].sum(), 0.0)
+        self.assertEqual(grads['b'].sum(), 0.0)
+
 
 class TestVmapOfGrad(TestCase):
     def test_per_sample_grads_inplace_view(self, device):
         def compute_loss(weight, x, t):