Fix linearize(grad(...)) call (#133364)

Fixes #124550

Also moves the `graph.eliminate_dead_code()` call to a few lines after the
`_inline_module(...)` call in `const_fold.py`.
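
For context, the failing composition looks roughly like the sketch below, adapted from the new test added in this PR (the exact reproducer is in issue #124550):

```python
import torch
from torch.func import grad, linearize

def fn(x):
    z = torch.ones(3)
    # The inner grad returns z, which does not depend on x, so the
    # forward-mode dual of fn's output carries a None tangent.
    return grad(lambda x: z @ x)(x)

x_p = torch.randn(3)  # primal point
x_t = torch.randn(3)  # tangent direction

# Before this fix, this call path failed on the None tangent (see the
# linked issue); now jvp_fn returns zeros, matching
# jvp(fn, (x_p,), (x_t,))[1].
_, jvp_fn = linearize(fn, x_p)
print(jvp_fn(x_t))
```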

Test plan:

Adds a new test in `test_eager_transforms.py` to ensure the reported
issue is indeed fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133364
Approved by: https://github.com/zou3519
Guilherme Leobas
2024-08-15 15:09:53 +00:00
committed by PyTorch MergeBot
parent cfec69e2a1
commit 5ec9c0bc4a
3 changed files with 23 additions and 4 deletions

test/functorch/test_eager_transforms.py

@@ -2882,7 +2882,7 @@ class TestLinearize(TestCase):
         self.assertEqual(actual_jvp, expected_jvp)
 
     @dtypes(torch.float)
-    def test_linearize_composition(self, device, dtype):
+    def test_linearize_composition_vmap(self, device, dtype):
         x_p = make_tensor((3, 1), device=device, dtype=dtype)
         x_t = make_tensor((3, 3, 1), device=device, dtype=dtype)
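
(The rename distinguishes the existing vmap-composition test from the grad-composition test added below.)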
@@ -2899,6 +2899,25 @@ class TestLinearize(TestCase):
         self.assertEqual(actual_batched_jvp, expected_batched_jvp)
 
     @dtypes(torch.float)
+    def test_linearize_composition_grad(self, device, dtype):
+        x_p = make_tensor((3,), device=device, dtype=dtype)
+        x_t = make_tensor((3,), device=device, dtype=dtype)
+
+        def fn(x):
+            z = torch.ones(3, device=device, dtype=dtype)
+            return grad(lambda x: z @ x)(x)
+
+        _, jvp_fn = linearize(fn, x_p)
+        actual_batched_jvp = jvp_fn(x_t)
+
+        def jvp_fn(x_t):
+            return jvp(fn, (x_p,), (x_t,))[1]
+
+        expected_batched_jvp = jvp_fn(x_t)
+
+        self.assertEqual(actual_batched_jvp, expected_batched_jvp)
+
+    @dtypes(torch.float)
     def test_linearize_nested_input_nested_output(self, device, dtype):
         x_p = make_tensor((3, 1), device=device, dtype=dtype)

torch/_functorch/eager_transforms.py

@@ -1781,7 +1781,7 @@ def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
             duals = tree_unflatten(flat_duals, primals_argspec)
             output = func(*duals)
             tangents = tree_map_only(
-                torch.Tensor, lambda t: fwAD.unpack_dual(t)[1], output
+                torch.Tensor, lambda dual: safe_unpack_dual(dual, False)[1], output
             )
 
         return tangents
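
`fwAD.unpack_dual` returns a `None` tangent when a dual number does not depend on any input tangent, which is exactly the case for the output of the inner `grad` call above. A rough sketch of what the `safe_unpack_dual` helper does in non-strict mode (an approximation for illustration, not the helper's verbatim source):

```python
import torch
import torch.autograd.forward_ad as fwAD

def safe_unpack_dual(dual, strict):
    # Approximation of the helper used above: unpack a dual number,
    # substituting a zero tangent when unpack_dual reports None.
    primal, tangent = fwAD.unpack_dual(dual)
    if tangent is None:
        if strict:
            raise RuntimeError("expected the output to depend on the tangents")
        tangent = torch.zeros_like(primal)
    return primal, tangent
```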

torch/fx/experimental/const_fold.py

@@ -275,14 +275,14 @@ def split_const_subgraphs(
             node.replace_all_uses_with(folded_attrs)
             break
 
-    split.graph.eliminate_dead_code()
-
     # Finally, inline the non-constant submod (if it exists) into the split submod.
     # This is so that the original caller who may have passed in a graph module will
     # get back out a graph module whose graph is traced to the same granularity.
     if hasattr(split, non_const_mod_name):
         _inline_module(split, non_const_mod_name)
 
+    split.graph.eliminate_dead_code()
+
     return FoldedGraphModule(
         split,
         split.graph,
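
Moving `eliminate_dead_code()` after `_inline_module(...)` presumably ensures that nodes which only become dead once the non-constant submodule is inlined back into `split.graph` are also cleaned up. For reference, a self-contained illustration of `torch.fx` dead-code elimination on a hypothetical module (not code from this PR):

```python
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        unused = x + 1  # traced into the graph, but its result is never used
        return x * 2

gm = fx.symbolic_trace(M())
print(len(list(gm.graph.nodes)))  # the dead add node is still present
gm.graph.eliminate_dead_code()    # removes side-effect-free nodes with no users
gm.recompile()
print(gm.code)                    # forward now only contains the mul
```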