Fix linearize(grad(...)) call (#133364)

Fixes #124550

Also moves `graph.eliminate_dead_code()` call to a few lines after
`_inline_module(...)` in `const_fold.py`

* Test plan:

Add a new test on `test_eager_transforms.py` to ensure the reported
issue was indeed fixed

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133364
Approved by: https://github.com/zou3519
This commit is contained in:
Guilherme Leobas
2024-08-15 15:09:53 +00:00
committed by PyTorch MergeBot
parent cfec69e2a1
commit 5ec9c0bc4a
3 changed files with 23 additions and 4 deletions

View File

@ -2882,7 +2882,7 @@ class TestLinearize(TestCase):
self.assertEqual(actual_jvp, expected_jvp)
@dtypes(torch.float)
def test_linearize_composition(self, device, dtype):
def test_linearize_composition_vmap(self, device, dtype):
x_p = make_tensor((3, 1), device=device, dtype=dtype)
x_t = make_tensor((3, 3, 1), device=device, dtype=dtype)
@ -2899,6 +2899,25 @@ class TestLinearize(TestCase):
self.assertEqual(actual_batched_jvp, expected_batched_jvp)
@dtypes(torch.float)
def test_linearize_composition_grad(self, device, dtype):
x_p = make_tensor((3,), device=device, dtype=dtype)
x_t = make_tensor((3,), device=device, dtype=dtype)
def fn(x):
z = torch.ones(3, device=device, dtype=dtype)
return grad(lambda x: z @ x)(x)
_, jvp_fn = linearize(fn, x_p)
actual_batched_jvp = jvp_fn(x_t)
def jvp_fn(x_t):
return jvp(fn, (x_p,), (x_t,))[1]
expected_batched_jvp = jvp_fn(x_t)
self.assertEqual(actual_batched_jvp, expected_batched_jvp)
@dtypes(torch.float)
def test_linearize_nested_input_nested_output(self, device, dtype):
x_p = make_tensor((3, 1), device=device, dtype=dtype)