[functorch] hessian API

This commit is contained in:
Richard Zou
2021-11-17 07:03:03 -08:00
committed by Jon Janzen
parent 4e75634a82
commit cdb56a0e62
3 changed files with 14 additions and 3 deletions

View File

@ -633,6 +633,9 @@ def jacfwd(f, argnums=0):
return tree_unflatten(jac_outs_ins, spec)
return wrapper_fn
def hessian(f, argnums=0):
return jacfwd(jacrev(f, argnums), argnums)
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
"""
Returns a function to compute a tuple of the gradient and primal, or

View File

@ -1,2 +1,2 @@
# PyTorch forward-mode is not mature yet
from .._src.eager_transforms import jvp, jacfwd
from .._src.eager_transforms import jvp, jacfwd, hessian

View File

@ -28,7 +28,7 @@ from functorch._src.make_functional import (
functional_init, functional_init_with_buffers,
)
from functorch.experimental import (
jvp, jacfwd,
jvp, jacfwd, hessian,
)
# NB: numpy is a testing dependency!
@ -811,7 +811,7 @@ class TestJac(TestCase):
assert torch.allclose(y, expected)
@FIXME_jacrev_only
def test_hessian_simple(self, device, jacapi):
def test_nested_jac_simple(self, device, jacapi):
def foo(x):
return x.sin().sum()
@ -1006,6 +1006,14 @@ class TestJac(TestCase):
with self.assertRaisesRegex(RuntimeError, "must be int"):
z = jacapi(torch.multiply, argnums=(1, 0.0))(x, x)
@unittest.expectedFailure
def test_hessian_simple(self, device):
def f(x):
return x.sin()
x = torch.randn(3, device=device)
result = hessian(f)(x)
class TestJvp(TestCase):
def test_inplace_on_captures(self, device):
x = torch.tensor([1., 2., 3.], device=device)