[functorch] Fix vjp and grad for unrelated outputs

Fixes pytorch/functorch#1.
This commit is contained in:
Richard Zou
2021-05-04 16:33:47 -07:00
committed by Jon Janzen
parent ec99d21d1e
commit 9490aa0b65
3 changed files with 88 additions and 13 deletions

View File

@ -50,6 +50,32 @@ def _wrap_tensor_for_grad(maybe_tensor, level):
def _wrap_all_tensors(tensor_or_tuple_of_tensors, level):
return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_or_tuple_of_tensors)
def _as_tuple(val):
if isinstance(val, tuple):
return val
return (val,)
# Version of autograd.grad that handles outputs that don't depend on inputs
def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
outputs = _as_tuple(outputs)
if grad_outputs is None:
diff_outputs = tuple(out for out in outputs if out.requires_grad)
else:
result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad)
if len(result) == 0:
diff_outputs, grad_outputs = (), ()
else:
diff_outputs, grad_outputs = zip(*result)
if len(diff_outputs) == 0:
return tuple(torch.zeros_like(inp) for inp in inputs)
grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
retain_graph=retain_graph,
create_graph=create_graph,
allow_unused=True)
grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
for gi, inp in zip(grad_inputs, inputs))
return grad_inputs
# How do we increment and decrement the nesting? I don't think we can.
def vjp(f, *primals):
level = _grad_increment_nesting()
@ -60,8 +86,13 @@ def vjp(f, *primals):
results = _undo_create_differentiable(primals_out, level)
def wrapper(*cotangents, retain_graph=True, create_graph=True):
result = torch.autograd.grad(primals_out, diff_primals, cotangents,
retain_graph=retain_graph, create_graph=create_graph)
primals_out_tuple = _as_tuple(primals_out)
if len(primals_out_tuple) != len(cotangents):
raise RuntimeError(
f'Got {len(primals_out_tuple)} outputs but {len(cotangents)} '
f'cotangents. These two quantities should be the same')
result = _autograd_grad(primals_out_tuple, diff_primals, cotangents,
retain_graph=retain_graph, create_graph=create_graph)
return result
finally:
@ -120,14 +151,7 @@ def grad_and_value(f, argnums=0, has_aux=False):
flat_diff_args, spec = tree_flatten(diff_args)
# NB: need create_graph so that backward pass isn't run in no_grad mode
flat_grad_input = torch.autograd.grad(
output, flat_diff_args, create_graph=True, allow_unused=True)
def replace_none_with_zeros(grad_arg, orig_arg):
if grad_arg is None:
return torch.zeros_like(orig_arg)
return grad_arg
flat_grad_input = [replace_none_with_zeros(x, flat_diff_args[idx]) for idx, x in enumerate(flat_grad_input)]
flat_grad_input = _autograd_grad(output, flat_diff_args, create_graph=True)
grad_input = tree_unflatten(flat_grad_input, spec)
finally:

View File

@ -9,7 +9,7 @@ import warnings
import math
from typing import Callable, Type
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
skipCUDAIfNoMagma, onlyOnCPUAndCUDA
skipCUDAIfNoMagma, onlyOnCPUAndCUDA, onlyCPU
import types
from functools import partial
@ -282,11 +282,63 @@ class TestGradTransform(TestCase):
def test_zero_grad(self, device):
def f(x):
return (x['a']**2.0).sum()
inps = ({'a':torch.randn(10) + 3, 'b':torch.randn(10)})
inps = ({'a':torch.randn(10, device=device) + 3, 'b':torch.randn(10, device=device)})
grads = grad(f)(inps)
self.assertNotEqual(grads['a'].sum(), 0.0)
self.assertEqual(grads['b'].sum(), 0.0)
def test_unrelated_grad(self, device):
x = torch.tensor(1., device=device)
y = torch.tensor(2., device=device)
def unrelated(x):
return y
result = grad(unrelated)(x)
self.assertEqual(result, torch.zeros_like(x))
def test_unrelated_vjp(self, device):
x = torch.tensor(1., device=device)
y = torch.tensor(2., device=device)
v = torch.tensor(1., device=device)
def unrelated(x):
return y
out, vjp_fn = vjp(unrelated, x)
result = vjp_fn(v)
expected = (torch.zeros_like(x),)
self.assertEqual(result, expected)
def test_unrelated_vjp_multiple_inputs_outputs(self, device):
w = torch.tensor(3., device=device)
x = torch.tensor(4., device=device)
y = torch.tensor(2., device=device)
v = torch.tensor(1., device=device)
def unrelated(w, x):
return y, y, x
out, vjp_fn = vjp(unrelated, w, x)
result = vjp_fn(v, v, v)
expected = (torch.zeros_like(x), torch.ones_like(x))
self.assertEqual(result, expected)
# TODO: https://github.com/zou3519/functorch/issues/12
@onlyCPU
def test_unrelated_hessian(self, device):
N = 5
M = 3
W = torch.randn(N, M, device=device)
def f(x):
return W @ x
x = torch.randn(M)
result = jacrev(jacrev(f))(x)
expected = torch.zeros(N, M, M, device=device)
self.assertEqual(result, expected)
class TestVmapOfGrad(TestCase):
def test_per_sample_grads_inplace_view(self, device):

View File

@ -522,7 +522,6 @@ class TestVmapAPI(TestCase):
vmap(foo, in_dims=(0,))(torch.randn(2, 3))
vmap(foo, in_dims=(1,))(torch.randn(2, 3))
@unittest.expectedFailure
def test_fallback_does_not_warn_by_default(self):
# NB: One day we will implement a batching rule for torch.atan2.
# If/when we do, this test should be replaced to test the fallback