Fix linearize(grad(...)) call (#133364)
Fixes #124550

Also moves the `graph.eliminate_dead_code()` call to a few lines after `_inline_module(...)` in `const_fold.py`.

* Test plan: Add a new test in `test_eager_transforms.py` to ensure the reported issue is indeed fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133364
Approved by: https://github.com/zou3519
commit 5ec9c0bc4a (parent cfec69e2a1)
committed by PyTorch MergeBot
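
For background on the `const_fold.py` change mentioned in the commit message: `Graph.eliminate_dead_code()` prunes `torch.fx` nodes whose outputs have no users, so where it runs relative to `_inline_module(...)` matters. Below is a minimal, self-contained illustration of that API only; the module and shapes are made up for the example and are not taken from `const_fold.py`.

import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        unused = x * 2   # traced into the graph even though the result is never used
        return x + 1

gm = fx.symbolic_trace(M())
print(len(list(gm.graph.nodes)))  # 4 nodes: placeholder, mul, add, output

gm.graph.eliminate_dead_code()    # drops the user-less mul node
gm.recompile()
print(len(list(gm.graph.nodes)))  # 3 nodes remain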
@@ -2882,7 +2882,7 @@ class TestLinearize(TestCase):
         self.assertEqual(actual_jvp, expected_jvp)
 
     @dtypes(torch.float)
-    def test_linearize_composition(self, device, dtype):
+    def test_linearize_composition_vmap(self, device, dtype):
         x_p = make_tensor((3, 1), device=device, dtype=dtype)
         x_t = make_tensor((3, 3, 1), device=device, dtype=dtype)
 
@@ -2899,6 +2899,25 @@ class TestLinearize(TestCase):
 
         self.assertEqual(actual_batched_jvp, expected_batched_jvp)
 
+    @dtypes(torch.float)
+    def test_linearize_composition_grad(self, device, dtype):
+        x_p = make_tensor((3,), device=device, dtype=dtype)
+        x_t = make_tensor((3,), device=device, dtype=dtype)
+
+        def fn(x):
+            z = torch.ones(3, device=device, dtype=dtype)
+            return grad(lambda x: z @ x)(x)
+
+        _, jvp_fn = linearize(fn, x_p)
+        actual_batched_jvp = jvp_fn(x_t)
+
+        def jvp_fn(x_t):
+            return jvp(fn, (x_p,), (x_t,))[1]
+
+        expected_batched_jvp = jvp_fn(x_t)
+
+        self.assertEqual(actual_batched_jvp, expected_batched_jvp)
+
     @dtypes(torch.float)
     def test_linearize_nested_input_nested_output(self, device, dtype):
         x_p = make_tensor((3, 1), device=device, dtype=dtype)