From af47e05a96d35ce789c96817196d572da8a52eb4 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Sun, 24 Nov 2024 06:41:30 +0000 Subject: [PATCH] [fx] make split_module work with keep_original_order=True and no-op graph (#141340) Fixes https://github.com/pytorch/pytorch/issues/140014 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141340 Approved by: https://github.com/ezyang --- test/test_fx_experimental.py | 28 ++++++++++++++++++++++++++++ torch/fx/passes/split_module.py | 11 +++++++++++ 2 files changed, 39 insertions(+) diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 8d842f101cd4..40cc6f1ad11a 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -893,6 +893,34 @@ terrible spacing x = torch.randn(50, 512) torch.testing.assert_close(split(x), traced(x)) + def test_split_module_keep_original_order_and_noop_graph(self): + # Verify that split_module returns a similar no-op graph + # for `keep_original_order={True|False}`. + def fn(x): + return (x,) + + g = make_fx(fn, tracing_mode="fake")(torch.randn(3, 3)) + + # g.graph.print_tabular() + # opcode name target args kwargs + # ----------- ------ -------- --------- -------- + # placeholder x_1 x_1 () {} + # output output output ((x_1,),) {} + + def _test_split_graph(split_gm): + # Verify that the split_gm has same structure as original + self.assertEqual(len(split_gm.graph.nodes), 2) + + nodes = list(split_gm.graph.nodes) + self.assertEqual(nodes[0].op, "placeholder") + self.assertEqual(nodes[1].op, "output") + + # `keep_original_order=False` + _test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=False)) + + # `keep_original_order=True` + _test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True)) + def test_normalize_binary_operators(self): ops_to_test = { torch.add, diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 0495a9520f63..7fec3089c527 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -601,6 +601,17 @@ def split_module( elif num_outputs == 1: base_mod_env[next(iter(partition.outputs))] = output_val + # When keep_original_order=True and if the graph doesn't have any + # `call_function` node then `base_mod_graph`, `base_mod_env` and `base_mod_attrs` + # are never populated. + # For this case, we call `construct_graph` here which takes care of updating them. + if keep_original_order and not base_mod_env: + for node in m.graph.nodes: + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + + # Add output node to `base_mod_graph` (i.e. the split graph) which will be returned. for node in m.graph.nodes: if node.op == "output": base_mod_graph.output(