[fx] make split_module work with keep_original_order=True and no-op graph (#141340)

Fixes https://github.com/pytorch/pytorch/issues/140014

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141340
Approved by: https://github.com/ezyang
This commit is contained in:
Kshiteej K
2024-11-24 06:41:30 +00:00
committed by PyTorch MergeBot
parent 4c1f50af5f
commit af47e05a96
2 changed files with 39 additions and 0 deletions

View File

@ -893,6 +893,34 @@ terrible spacing
x = torch.randn(50, 512)
torch.testing.assert_close(split(x), traced(x))
def test_split_module_keep_original_order_and_noop_graph(self):
# Verify that split_module returns a similar no-op graph
# for `keep_original_order={True|False}`.
def fn(x):
return (x,)
g = make_fx(fn, tracing_mode="fake")(torch.randn(3, 3))
# g.graph.print_tabular()
# opcode name target args kwargs
# ----------- ------ -------- --------- --------
# placeholder x_1 x_1 () {}
# output output output ((x_1,),) {}
def _test_split_graph(split_gm):
# Verify that the split_gm has same structure as original
self.assertEqual(len(split_gm.graph.nodes), 2)
nodes = list(split_gm.graph.nodes)
self.assertEqual(nodes[0].op, "placeholder")
self.assertEqual(nodes[1].op, "output")
# `keep_original_order=False`
_test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=False))
# `keep_original_order=True`
_test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True))
def test_normalize_binary_operators(self):
ops_to_test = {
torch.add,