[functorch] adds support for grad if inputs don't depend on output (pytorch/functorch#6)
* adds support for grad if inputs don't depend on output
* fix some issues
* responded to comments
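With this change, functorch's grad (and grad_and_value) return zero-filled gradients for inputs that the output does not depend on, instead of erroring. A minimal usage sketch, not part of this commit, assuming the public functorch.grad API exercised in the new test below:

import torch
from functorch import grad

def f(inps):
    # only inps['a'] contributes to the output; inps['b'] is unused
    return (inps['a'] ** 2).sum()

inps = {'a': torch.randn(10), 'b': torch.randn(10)}
grads = grad(f)(inps)
# grads['a'] equals 2 * inps['a']; grads['b'] comes back as all zeros rather than None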
@@ -17,28 +17,21 @@ from functorch._C import (
     _grad_decrement_nesting,
 )
 
-# TODO: replace all of these with pytrees
-def _create_differentiable(tensor_or_tuple_of_tensors, level=None):
-    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
-        tensor = tensor_or_tuple_of_tensors
-        aliased = tensor
-        return aliased.requires_grad_()
-    if isinstance(tensor_or_tuple_of_tensors, tuple):
-        return tuple(map(partial(_create_differentiable, level=level), tensor_or_tuple_of_tensors))
-    if isinstance(tensor_or_tuple_of_tensors, list):
-        return tuple(map(partial(_create_differentiable, level=level), tensor_or_tuple_of_tensors))
-    raise ValueError(f'Thing passed to transform API must be Tensor, List or Tuple, '
-                     f'got {type(tensor_or_tuple_of_tensors)}')
+def _create_differentiable(inps, level=None):
+    def create_differentiable(x):
+        if isinstance(x, torch.Tensor):
+            return x.requires_grad_()
+        raise ValueError(f'Thing passed to transform API must be Tensor, '
+                         f'got {type(x)}')
+    return tree_map(create_differentiable, inps)
 
-def _undo_create_differentiable(tensor_or_tuple_of_tensors, level=None):
-    if isinstance(tensor_or_tuple_of_tensors, torch.Tensor):
-        tensor = tensor_or_tuple_of_tensors
-        return _unwrap_for_grad(tensor, level)
-    if isinstance(tensor_or_tuple_of_tensors, tuple):
-        return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
-    if isinstance(tensor_or_tuple_of_tensors, list):
-        return tuple(map(partial(_undo_create_differentiable, level=level), tensor_or_tuple_of_tensors))
-    assert False
+def _undo_create_differentiable(inps, level=None):
+    def unwrap_tensors(x):
+        if isinstance(x, torch.Tensor):
+            return _unwrap_for_grad(x, level)
+        assert False
+
+    return tree_map(unwrap_tensors, inps)
 
 def _is_differentiable(maybe_tensor):
     if not isinstance(maybe_tensor, torch.Tensor):
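The rewritten helpers lean on pytree tree_map to apply a per-tensor function over arbitrarily nested inputs (dicts, lists, tuples) instead of hand-rolled tuple/list recursion. A rough sketch of the idea, assuming torch.utils._pytree.tree_map behaves like the tree_map used in the diff:

import torch
from torch.utils._pytree import tree_map

inps = {'a': torch.randn(2), 'b': [torch.randn(3), torch.randn(4)]}

# Mark every tensor leaf as differentiable while preserving the container structure.
marked = tree_map(lambda x: x.requires_grad_(), inps)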
@@ -128,8 +121,13 @@ def grad_and_value(f, argnums=0, has_aux=False):
 
             # NB: need create_graph so that backward pass isn't run in no_grad mode
             flat_grad_input = torch.autograd.grad(
-                output, flat_diff_args, create_graph=True)
+                output, flat_diff_args, create_graph=True, allow_unused=True)
 
+            def replace_none_with_zeros(grad_arg, orig_arg):
+                if grad_arg is None:
+                    return torch.zeros_like(orig_arg)
+                return grad_arg
+            flat_grad_input = [replace_none_with_zeros(x, flat_diff_args[idx]) for idx, x in enumerate(flat_grad_input)]
             grad_input = tree_unflatten(flat_grad_input, spec)
 
         finally:
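The mechanism in the hunk above: allow_unused=True makes torch.autograd.grad return None for inputs that the output does not depend on, and replace_none_with_zeros converts each None into a zeros_like tensor so callers always receive a well-formed gradient. A standalone sketch in plain PyTorch (not functorch) illustrating that behavior:

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
out = (a ** 2).sum()  # b is never used

grads = torch.autograd.grad(out, (a, b), create_graph=True, allow_unused=True)
# grads == (2 * a, None); map the None to zeros, mirroring replace_none_with_zeros
grads = tuple(torch.zeros_like(inp) if g is None else g for g, inp in zip(grads, (a, b)))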
@@ -89,7 +89,6 @@ def wrap_key(f, inps):
     @functools.wraps(f)
     def wrapped(*args):
         flat_args, args_spec = pytree.tree_flatten(args)
-        import pdb; pdb.set_trace()
         assert(len(flat_args) == len(flat_inps))
         for idx, arg in enumerate(flat_args):
             if isinstance(flat_inps[idx], torch.Tensor):
@@ -279,6 +279,15 @@ class TestGradTransform(TestCase):
         self.assertEqual(gx, y)
         self.assertEqual(gy, x)
 
+    def test_zero_grad(self, device):
+        def f(x):
+            return (x['a']**2.0).sum()
+        inps = ({'a':torch.randn(10) + 3, 'b':torch.randn(10)})
+        grads = grad(f)(inps)
+        self.assertNotEqual(grads['a'].sum(), 0.0)
+        self.assertEqual(grads['b'].sum(), 0.0)
+
+
 class TestVmapOfGrad(TestCase):
     def test_per_sample_grads_inplace_view(self, device):
         def compute_loss(weight, x, t):