Fix linearize(grad(...)) call (#133364)
Fixes #124550.

Also moves the `graph.eliminate_dead_code()` call to a few lines after `_inline_module(...)` in `const_fold.py`.

* Test plan: add a new test in `test_eager_transforms.py` to ensure the reported issue is indeed fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133364
Approved by: https://github.com/zou3519
Committed by: PyTorch MergeBot
Parent: cfec69e2a1
Commit: 5ec9c0bc4a
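
For context, a minimal usage-level sketch of the composition this PR fixes, adapted from the new test added below; the tensor shapes and values here are illustrative, and `linearize`, `grad`, and `jvp` are taken from `torch.func`:

```python
# Sketch of linearize(grad(...)) per the fixed issue; before this change,
# invoking the linearized callable of a function that wraps grad(...) failed.
import torch
from torch.func import grad, jvp, linearize

def fn(x):
    z = torch.ones(3)
    # grad(...) inside the function being linearized is the failing composition.
    return grad(lambda x: z @ x)(x)

x_p = torch.randn(3)  # primal point
x_t = torch.randn(3)  # tangent direction

_, jvp_fn = linearize(fn, x_p)
actual = jvp_fn(x_t)

# Reference value computed directly with jvp.
expected = jvp(fn, (x_p,), (x_t,))[1]
torch.testing.assert_close(actual, expected)
```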
In `test_eager_transforms.py`, the existing composition test is renamed and a new test composing `linearize` with `grad` is added:

```diff
@@ -2882,7 +2882,7 @@ class TestLinearize(TestCase):
         self.assertEqual(actual_jvp, expected_jvp)
 
     @dtypes(torch.float)
-    def test_linearize_composition(self, device, dtype):
+    def test_linearize_composition_vmap(self, device, dtype):
         x_p = make_tensor((3, 1), device=device, dtype=dtype)
         x_t = make_tensor((3, 3, 1), device=device, dtype=dtype)
 
@@ -2899,6 +2899,25 @@ class TestLinearize(TestCase):
 
         self.assertEqual(actual_batched_jvp, expected_batched_jvp)
 
+    @dtypes(torch.float)
+    def test_linearize_composition_grad(self, device, dtype):
+        x_p = make_tensor((3,), device=device, dtype=dtype)
+        x_t = make_tensor((3,), device=device, dtype=dtype)
+
+        def fn(x):
+            z = torch.ones(3, device=device, dtype=dtype)
+            return grad(lambda x: z @ x)(x)
+
+        _, jvp_fn = linearize(fn, x_p)
+        actual_batched_jvp = jvp_fn(x_t)
+
+        def jvp_fn(x_t):
+            return jvp(fn, (x_p,), (x_t,))[1]
+
+        expected_batched_jvp = jvp_fn(x_t)
+
+        self.assertEqual(actual_batched_jvp, expected_batched_jvp)
+
     @dtypes(torch.float)
     def test_linearize_nested_input_nested_output(self, device, dtype):
         x_p = make_tensor((3, 1), device=device, dtype=dtype)
```
In the `linearize` implementation, tangent extraction switches from `fwAD.unpack_dual` to `safe_unpack_dual`:

```diff
@@ -1781,7 +1781,7 @@ def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
         duals = tree_unflatten(flat_duals, primals_argspec)
         output = func(*duals)
         tangents = tree_map_only(
-            torch.Tensor, lambda t: fwAD.unpack_dual(t)[1], output
+            torch.Tensor, lambda dual: safe_unpack_dual(dual, False)[1], output
         )
 
         return tangents
```
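
To see why a guarded unpack matters here: when the output of the function being linearized does not depend on its dual inputs (as with the `grad(...)` composition above, whose result is constant in `x`), `fwAD.unpack_dual` reports no tangent, so indexing `[1]` yields `None` instead of a tensor. A small standalone illustration using the public forward-AD API; the `unpack_dual_or_zeros` helper is hypothetical and only sketches the kind of fallback a safe unpack needs, it is not the PR's `safe_unpack_dual`:

```python
import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3)
t = torch.randn(3)

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, t)
    out = torch.ones(3)                 # constant w.r.t. dual_x
    primal, tangent = fwAD.unpack_dual(out)
    print(tangent)                      # None: no tangent was attached to `out`

# Hypothetical fallback: substitute zeros when no tangent is present.
def unpack_dual_or_zeros(dual):
    primal, tangent = fwAD.unpack_dual(dual)
    return primal, tangent if tangent is not None else torch.zeros_like(primal)
```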
In `const_fold.py`, the `split.graph.eliminate_dead_code()` call moves from before `_inline_module(...)` to after it:

```diff
@@ -275,14 +275,14 @@ def split_const_subgraphs(
             node.replace_all_uses_with(folded_attrs)
             break
 
-    split.graph.eliminate_dead_code()
-
     # Finally, inline the non-constant submod (if it exists) into the split submod.
     # This is so that the original caller who may have passed in a graph module will
     # get back out a graph module whose graph is traced to the same granularity.
     if hasattr(split, non_const_mod_name):
         _inline_module(split, non_const_mod_name)
 
+    split.graph.eliminate_dead_code()
+
     return FoldedGraphModule(
         split,
         split.graph,
```
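
A brief, assumed illustration of what the relocated call does: `Graph.eliminate_dead_code()` drops side-effect-free nodes whose results have no users, so running it after `_inline_module(...)` lets it also clean up nodes that only become dead once the non-constant submodule has been inlined back (this rationale is inferred from the reordering, not stated in the commit). Sketch with the public `torch.fx` API:

```python
import torch
import torch.fx as fx

class M(torch.nn.Module):
    def forward(self, x):
        unused = x + 1   # traced into the graph, but its result is never used
        return x * 2

gm = fx.symbolic_trace(M())
print(len(gm.graph.nodes))      # includes the dead `x + 1` node

gm.graph.eliminate_dead_code()  # remove side-effect-free nodes with no users
gm.recompile()                  # regenerate forward() from the cleaned graph
print(len(gm.graph.nodes))      # the dead node is gone
```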