mirror of https://github.com/pytorch/pytorch.git
[functorch] Fix vjp and grad for unrelated outputs
Fixes pytorch/functorch#1.
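What the fix means in practice: grad and vjp now return zero gradients when an output does not depend on the input, instead of erroring out. A minimal usage sketch, not part of the diff, assuming a functorch build where grad and vjp are importable from the top-level package (as the tests added below do):

import torch
from functorch import grad, vjp

x = torch.tensor(1.)
y = torch.tensor(2.)

def unrelated(x):
    # The output does not depend on the input at all.
    return y

# Gradient of an unrelated output is a zero tensor shaped like the input.
print(grad(unrelated)(x))        # tensor(0.)

# The vjp of the same function produces a zero cotangent for the input.
out, vjp_fn = vjp(unrelated, x)
print(vjp_fn(torch.tensor(1.)))  # (tensor(0.),)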
@@ -50,6 +50,32 @@ def _wrap_tensor_for_grad(maybe_tensor, level):
 def _wrap_all_tensors(tensor_or_tuple_of_tensors, level):
     return tree_map(partial(_wrap_tensor_for_grad, level=level), tensor_or_tuple_of_tensors)
 
+def _as_tuple(val):
+    if isinstance(val, tuple):
+        return val
+    return (val,)
+
+# Version of autograd.grad that handles outputs that don't depend on inputs
+def _autograd_grad(outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True):
+    outputs = _as_tuple(outputs)
+    if grad_outputs is None:
+        diff_outputs = tuple(out for out in outputs if out.requires_grad)
+    else:
+        result = tuple((out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad)
+        if len(result) == 0:
+            diff_outputs, grad_outputs = (), ()
+        else:
+            diff_outputs, grad_outputs = zip(*result)
+    if len(diff_outputs) == 0:
+        return tuple(torch.zeros_like(inp) for inp in inputs)
+    grad_inputs = torch.autograd.grad(diff_outputs, inputs, grad_outputs,
+                                      retain_graph=retain_graph,
+                                      create_graph=create_graph,
+                                      allow_unused=True)
+    grad_inputs = tuple(torch.zeros_like(inp) if gi is None else gi
+                        for gi, inp in zip(grad_inputs, inputs))
+    return grad_inputs
+
 # How do we increment and decrement the nesting? I don't think we can.
 def vjp(f, *primals):
     level = _grad_increment_nesting()
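_autograd_grad filters out outputs that do not require grad and, via allow_unused=True plus the final comprehension, replaces missing gradients with zeros. A standalone sketch of the same idea on plain PyTorch, independent of functorch internals (names here are illustrative):

import torch

x = torch.tensor(1., requires_grad=True)
y = torch.tensor(2.)   # unrelated to x, does not require grad

outputs, inputs = (y,), (x,)
# Keep only outputs that are connected to the autograd graph.
diff_outputs = tuple(o for o in outputs if o.requires_grad)

if len(diff_outputs) == 0:
    # Nothing to differentiate: every gradient is zero.
    grads = tuple(torch.zeros_like(inp) for inp in inputs)
else:
    grads = torch.autograd.grad(diff_outputs, inputs, allow_unused=True)
    grads = tuple(torch.zeros_like(inp) if g is None else g
                  for g, inp in zip(grads, inputs))

print(grads)   # (tensor(0.),)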
@@ -60,8 +86,13 @@ def vjp(f, *primals):
         results = _undo_create_differentiable(primals_out, level)
 
         def wrapper(*cotangents, retain_graph=True, create_graph=True):
-            result = torch.autograd.grad(primals_out, diff_primals, cotangents,
-                                         retain_graph=retain_graph, create_graph=create_graph)
+            primals_out_tuple = _as_tuple(primals_out)
+            if len(primals_out_tuple) != len(cotangents):
+                raise RuntimeError(
+                    f'Got {len(primals_out_tuple)} outputs but {len(cotangents)} '
+                    f'cotangents. These two quantities should be the same')
+            result = _autograd_grad(primals_out_tuple, diff_primals, cotangents,
+                                    retain_graph=retain_graph, create_graph=create_graph)
             return result
 
     finally:
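Besides routing through _autograd_grad, the wrapper now insists that callers pass exactly one cotangent per output. A usage sketch of that check, assuming functorch.vjp as above:

import torch
from functorch import vjp

x = torch.randn(3)
out, vjp_fn = vjp(torch.sin, x)   # a single output

v = torch.ones(3)
vjp_fn(v)       # fine: one cotangent for one output
# vjp_fn(v, v)  # would raise:
# RuntimeError: Got 1 outputs but 2 cotangents. These two quantities should be the same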
@@ -120,14 +151,7 @@ def grad_and_value(f, argnums=0, has_aux=False):
             flat_diff_args, spec = tree_flatten(diff_args)
 
             # NB: need create_graph so that backward pass isn't run in no_grad mode
-            flat_grad_input = torch.autograd.grad(
-                output, flat_diff_args, create_graph=True, allow_unused=True)
-
-            def replace_none_with_zeros(grad_arg, orig_arg):
-                if grad_arg is None:
-                    return torch.zeros_like(orig_arg)
-                return grad_arg
-            flat_grad_input = [replace_none_with_zeros(x, flat_diff_args[idx]) for idx, x in enumerate(flat_grad_input)]
+            flat_grad_input = _autograd_grad(output, flat_diff_args, create_graph=True)
             grad_input = tree_unflatten(flat_grad_input, spec)
 
         finally:
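Because _autograd_grad already maps missing gradients to zeros, the local replace_none_with_zeros helper is no longer needed, and the existing behavior for pytree leaves that do not affect the output is preserved. A sketch mirroring test_zero_grad below, assuming functorch's grad is importable as above:

import torch
from functorch import grad

def f(inps):
    return (inps['a'] ** 2.0).sum()   # 'b' never participates

inps = {'a': torch.randn(10) + 3, 'b': torch.randn(10)}
grads = grad(f)(inps)

print(bool(grads['a'].abs().sum() > 0))   # True: real gradient for 'a'
print(grads['b'])                         # all zeros for the unused leaf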
@@ -9,7 +9,7 @@ import warnings
 import math
 from typing import Callable, Type
 from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
-    skipCUDAIfNoMagma, onlyOnCPUAndCUDA
+    skipCUDAIfNoMagma, onlyOnCPUAndCUDA, onlyCPU
 import types
 from functools import partial
 
@@ -282,11 +282,63 @@ class TestGradTransform(TestCase):
     def test_zero_grad(self, device):
         def f(x):
             return (x['a']**2.0).sum()
-        inps = ({'a':torch.randn(10) + 3, 'b':torch.randn(10)})
+        inps = ({'a':torch.randn(10, device=device) + 3, 'b':torch.randn(10, device=device)})
         grads = grad(f)(inps)
         self.assertNotEqual(grads['a'].sum(), 0.0)
         self.assertEqual(grads['b'].sum(), 0.0)
 
+    def test_unrelated_grad(self, device):
+        x = torch.tensor(1., device=device)
+        y = torch.tensor(2., device=device)
+
+        def unrelated(x):
+            return y
+
+        result = grad(unrelated)(x)
+        self.assertEqual(result, torch.zeros_like(x))
+
+    def test_unrelated_vjp(self, device):
+        x = torch.tensor(1., device=device)
+        y = torch.tensor(2., device=device)
+        v = torch.tensor(1., device=device)
+
+        def unrelated(x):
+            return y
+
+        out, vjp_fn = vjp(unrelated, x)
+        result = vjp_fn(v)
+        expected = (torch.zeros_like(x),)
+        self.assertEqual(result, expected)
+
+    def test_unrelated_vjp_multiple_inputs_outputs(self, device):
+        w = torch.tensor(3., device=device)
+        x = torch.tensor(4., device=device)
+        y = torch.tensor(2., device=device)
+        v = torch.tensor(1., device=device)
+
+        def unrelated(w, x):
+            return y, y, x
+
+        out, vjp_fn = vjp(unrelated, w, x)
+        result = vjp_fn(v, v, v)
+        expected = (torch.zeros_like(x), torch.ones_like(x))
+        self.assertEqual(result, expected)
+
+    # TODO: https://github.com/zou3519/functorch/issues/12
+    @onlyCPU
+    def test_unrelated_hessian(self, device):
+        N = 5
+        M = 3
+        W = torch.randn(N, M, device=device)
+
+        def f(x):
+            return W @ x
+
+        x = torch.randn(M)
+        result = jacrev(jacrev(f))(x)
+        expected = torch.zeros(N, M, M, device=device)
+        self.assertEqual(result, expected)
+
 
 class TestVmapOfGrad(TestCase):
     def test_per_sample_grads_inplace_view(self, device):
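These tests take a device argument and use decorators such as @onlyCPU, which rely on instantiate_device_type_tests (imported in the hunk above) to generate per-device test classes. A small illustrative sketch of that mechanism; the class and test names here are hypothetical, not from this diff:

import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCPU

class TestDeviceParamExample(TestCase):
    def test_runs_on_every_available_device(self, device):
        x = torch.tensor(1., device=device)
        self.assertEqual(torch.zeros_like(x).sum(), 0.0)

    @onlyCPU
    def test_runs_on_cpu_only(self, device):
        self.assertEqual(device, 'cpu')

# Generates TestDeviceParamExampleCPU (and a CUDA variant, if available) in this module.
instantiate_device_type_tests(TestDeviceParamExample, globals())

if __name__ == '__main__':
    run_tests()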
@@ -522,7 +522,6 @@ class TestVmapAPI(TestCase):
         vmap(foo, in_dims=(0,))(torch.randn(2, 3))
         vmap(foo, in_dims=(1,))(torch.randn(2, 3))
 
-    @unittest.expectedFailure
     def test_fallback_does_not_warn_by_default(self):
         # NB: One day we will implement a batching rule for torch.atan2.
         # If/when we do, this test should be replaced to test the fallback