Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Fix linearize(grad(...)) call (#133364)
Fixes #124550

Also moves the `graph.eliminate_dead_code()` call to a few lines after `_inline_module(...)` in `const_fold.py`.

Test plan: add a new test to `test_eager_transforms.py` to ensure the reported issue is indeed fixed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133364
Approved by: https://github.com/zou3519
Committed by PyTorch MergeBot
Parent: cfec69e2a1
Commit: 5ec9c0bc4a
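For context on the failure being fixed: passing a function that itself calls `grad(...)` to `linearize` raised an error (#124550). Below is a minimal repro sketch in the spirit of the new test added in this commit; the shapes and the `z @ x` objective mirror that test and are illustrative only, not copied from the issue.

import torch
from torch.func import grad, jvp, linearize

def fn(x):
    # gradient of the linear objective z @ x with respect to x (== z, independent of x)
    z = torch.ones(3)
    return grad(lambda x: z @ x)(x)

x_p = torch.randn(3)  # primal point
x_t = torch.randn(3)  # tangent direction

# Before this fix, this call failed because fn internally uses grad(...).
_, jvp_fn = linearize(fn, x_p)
actual = jvp_fn(x_t)

# Reference computed with jvp directly; fn is constant in x, so the tangent is zero.
expected = jvp(fn, (x_p,), (x_t,))[1]
torch.testing.assert_close(actual, expected)

Since `fn`'s output does not depend on `x`, the forward-mode tangent of the output is zero, which is exactly the case the strict dual unpacking in `linearize` did not handle.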
@@ -2882,7 +2882,7 @@ class TestLinearize(TestCase):
         self.assertEqual(actual_jvp, expected_jvp)

     @dtypes(torch.float)
-    def test_linearize_composition(self, device, dtype):
+    def test_linearize_composition_vmap(self, device, dtype):
         x_p = make_tensor((3, 1), device=device, dtype=dtype)
         x_t = make_tensor((3, 3, 1), device=device, dtype=dtype)

@@ -2899,6 +2899,25 @@ class TestLinearize(TestCase):

         self.assertEqual(actual_batched_jvp, expected_batched_jvp)

+    @dtypes(torch.float)
+    def test_linearize_composition_grad(self, device, dtype):
+        x_p = make_tensor((3,), device=device, dtype=dtype)
+        x_t = make_tensor((3,), device=device, dtype=dtype)
+
+        def fn(x):
+            z = torch.ones(3, device=device, dtype=dtype)
+            return grad(lambda x: z @ x)(x)
+
+        _, jvp_fn = linearize(fn, x_p)
+        actual_batched_jvp = jvp_fn(x_t)
+
+        def jvp_fn(x_t):
+            return jvp(fn, (x_p,), (x_t,))[1]
+
+        expected_batched_jvp = jvp_fn(x_t)
+
+        self.assertEqual(actual_batched_jvp, expected_batched_jvp)
+
     @dtypes(torch.float)
     def test_linearize_nested_input_nested_output(self, device, dtype):
         x_p = make_tensor((3, 1), device=device, dtype=dtype)
@@ -1781,7 +1781,7 @@ def linearize(func: Callable, *primals) -> Tuple[Any, Callable]:
             duals = tree_unflatten(flat_duals, primals_argspec)
             output = func(*duals)
             tangents = tree_map_only(
-                torch.Tensor, lambda t: fwAD.unpack_dual(t)[1], output
+                torch.Tensor, lambda dual: safe_unpack_dual(dual, False)[1], output
             )

         return tangents
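The replaced line above is the functional part of the fix: when `func` itself calls `grad(...)`, the traced output may carry no forward-mode tangent, so `fwAD.unpack_dual(t)[1]` yields `None` and the rest of the trace breaks. Switching to a non-strict unpacking helper substitutes a zero tangent instead. The sketch below shows the assumed behavior of such a helper; the name is hypothetical (suffix `_sketch` added) and the body is an illustration, not copied from the PyTorch source.

import torch
import torch.autograd.forward_ad as fwAD

def safe_unpack_dual_sketch(dual, strict):
    # Unpack a forward-AD dual tensor. With strict=False, tolerate outputs that
    # do not depend on the dual inputs by returning a zero tangent instead of None.
    primal, tangent = fwAD.unpack_dual(dual)
    if tangent is None:
        if strict:
            raise RuntimeError("output does not depend on the dual inputs")
        tangent = torch.zeros_like(primal)
    return primal, tangent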
@@ -275,14 +275,14 @@ def split_const_subgraphs(
             node.replace_all_uses_with(folded_attrs)
             break

-    split.graph.eliminate_dead_code()
-
     # Finally, inline the non-constant submod (if it exists) into the split submod.
     # This is so that the original caller who may have passed in a graph module will
     # get back out a graph module whose graph is traced to the same granularity.
     if hasattr(split, non_const_mod_name):
         _inline_module(split, non_const_mod_name)

+    split.graph.eliminate_dead_code()
+
     return FoldedGraphModule(
         split,
         split.graph,
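The `const_fold.py` change only reorders the cleanup: `eliminate_dead_code()` now runs after the non-constant submodule has been inlined back, so dead-code elimination sees the final, inlined graph. For orientation, here is a minimal usage sketch of the constant-folding entry point touched by this hunk; the toy module and tensor values are made up for illustration, and `run_folding()` is assumed to be the way to materialize the folded constants before calling the module.

import torch
import torch.fx
from torch.fx.experimental.const_fold import split_const_subgraphs

class AddConst(torch.nn.Module):
    # torch.ones(3) + 1 involves only constants, so it can be folded ahead of time.
    def forward(self, x):
        const = torch.ones(3) + 1
        return x + const

# split_const_subgraphs returns a FoldedGraphModule wrapping the split graph.
folded = split_const_subgraphs(torch.fx.symbolic_trace(AddConst()))
folded.run_folding()           # precompute the constant subgraph into attributes
print(folded(torch.zeros(3)))  # expected: tensor([2., 2., 2.])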