[fx] make split_module work with keep_original_order=True and no-op graph (#141340)

Fixes https://github.com/pytorch/pytorch/issues/140014

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141340
Approved by: https://github.com/ezyang
This commit is contained in:
Kshiteej K
2024-11-24 06:41:30 +00:00
committed by PyTorch MergeBot
parent 4c1f50af5f
commit af47e05a96
2 changed files with 39 additions and 0 deletions

View File

@ -893,6 +893,34 @@ terrible spacing
x = torch.randn(50, 512)
torch.testing.assert_close(split(x), traced(x))
def test_split_module_keep_original_order_and_noop_graph(self):
# Verify that split_module returns a similar no-op graph
# for `keep_original_order={True|False}`.
def fn(x):
return (x,)
g = make_fx(fn, tracing_mode="fake")(torch.randn(3, 3))
# g.graph.print_tabular()
# opcode name target args kwargs
# ----------- ------ -------- --------- --------
# placeholder x_1 x_1 () {}
# output output output ((x_1,),) {}
def _test_split_graph(split_gm):
# Verify that the split_gm has same structure as original
self.assertEqual(len(split_gm.graph.nodes), 2)
nodes = list(split_gm.graph.nodes)
self.assertEqual(nodes[0].op, "placeholder")
self.assertEqual(nodes[1].op, "output")
# `keep_original_order=False`
_test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=False))
# `keep_original_order=True`
_test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True))
def test_normalize_binary_operators(self):
ops_to_test = {
torch.add,

View File

@ -601,6 +601,17 @@ def split_module(
elif num_outputs == 1:
base_mod_env[next(iter(partition.outputs))] = output_val
# When keep_original_order=True and if the graph doesn't have any
# `call_function` node then `base_mod_graph`, `base_mod_env` and `base_mod_attrs`
# are never populated.
# For this case, we call `construct_graph` here which takes care of updating them.
if keep_original_order and not base_mod_env:
for node in m.graph.nodes:
base_mod_env, base_mod_attrs = construct_graph(
node, base_mod_env, base_mod_attrs
)
# Add output node to `base_mod_graph` (i.e. the split graph) which will be returned.
for node in m.graph.nodes:
if node.op == "output":
base_mod_graph.output(