mirror of
https://github.com/pytorch/pytorch.git
synced 2025-11-02 06:24:59 +08:00
[functorch] hessian API
This commit is contained in:
@ -633,6 +633,9 @@ def jacfwd(f, argnums=0):
|
||||
return tree_unflatten(jac_outs_ins, spec)
|
||||
return wrapper_fn
|
||||
|
||||
def hessian(f, argnums=0):
|
||||
return jacfwd(jacrev(f, argnums), argnums)
|
||||
|
||||
def grad_and_value(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
|
||||
"""
|
||||
Returns a function to compute a tuple of the gradient and primal, or
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
# PyTorch forward-mode is not mature yet
|
||||
from .._src.eager_transforms import jvp, jacfwd
|
||||
from .._src.eager_transforms import jvp, jacfwd, hessian
|
||||
|
||||
@ -28,7 +28,7 @@ from functorch._src.make_functional import (
|
||||
functional_init, functional_init_with_buffers,
|
||||
)
|
||||
from functorch.experimental import (
|
||||
jvp, jacfwd,
|
||||
jvp, jacfwd, hessian,
|
||||
)
|
||||
|
||||
# NB: numpy is a testing dependency!
|
||||
@ -811,7 +811,7 @@ class TestJac(TestCase):
|
||||
assert torch.allclose(y, expected)
|
||||
|
||||
@FIXME_jacrev_only
|
||||
def test_hessian_simple(self, device, jacapi):
|
||||
def test_nested_jac_simple(self, device, jacapi):
|
||||
def foo(x):
|
||||
return x.sin().sum()
|
||||
|
||||
@ -1006,6 +1006,14 @@ class TestJac(TestCase):
|
||||
with self.assertRaisesRegex(RuntimeError, "must be int"):
|
||||
z = jacapi(torch.multiply, argnums=(1, 0.0))(x, x)
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_hessian_simple(self, device):
|
||||
def f(x):
|
||||
return x.sin()
|
||||
|
||||
x = torch.randn(3, device=device)
|
||||
result = hessian(f)(x)
|
||||
|
||||
class TestJvp(TestCase):
|
||||
def test_inplace_on_captures(self, device):
|
||||
x = torch.tensor([1., 2., 3.], device=device)
|
||||
|
||||
Reference in New Issue
Block a user