mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fx] make split_module work with keep_original_order=True and no-op graph (#141340)
Fixes https://github.com/pytorch/pytorch/issues/140014 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141340 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
4c1f50af5f
commit
af47e05a96
@ -893,6 +893,34 @@ terrible spacing
|
||||
x = torch.randn(50, 512)
|
||||
torch.testing.assert_close(split(x), traced(x))
|
||||
|
||||
def test_split_module_keep_original_order_and_noop_graph(self):
|
||||
# Verify that split_module returns a similar no-op graph
|
||||
# for `keep_original_order={True|False}`.
|
||||
def fn(x):
|
||||
return (x,)
|
||||
|
||||
g = make_fx(fn, tracing_mode="fake")(torch.randn(3, 3))
|
||||
|
||||
# g.graph.print_tabular()
|
||||
# opcode name target args kwargs
|
||||
# ----------- ------ -------- --------- --------
|
||||
# placeholder x_1 x_1 () {}
|
||||
# output output output ((x_1,),) {}
|
||||
|
||||
def _test_split_graph(split_gm):
|
||||
# Verify that the split_gm has same structure as original
|
||||
self.assertEqual(len(split_gm.graph.nodes), 2)
|
||||
|
||||
nodes = list(split_gm.graph.nodes)
|
||||
self.assertEqual(nodes[0].op, "placeholder")
|
||||
self.assertEqual(nodes[1].op, "output")
|
||||
|
||||
# `keep_original_order=False`
|
||||
_test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=False))
|
||||
|
||||
# `keep_original_order=True`
|
||||
_test_split_graph(split_module(g, None, split_callback=lambda _ : 0, keep_original_order=True))
|
||||
|
||||
def test_normalize_binary_operators(self):
|
||||
ops_to_test = {
|
||||
torch.add,
|
||||
|
@ -601,6 +601,17 @@ def split_module(
|
||||
elif num_outputs == 1:
|
||||
base_mod_env[next(iter(partition.outputs))] = output_val
|
||||
|
||||
# When keep_original_order=True and if the graph doesn't have any
|
||||
# `call_function` node then `base_mod_graph`, `base_mod_env` and `base_mod_attrs`
|
||||
# are never populated.
|
||||
# For this case, we call `construct_graph` here which takes care of updating them.
|
||||
if keep_original_order and not base_mod_env:
|
||||
for node in m.graph.nodes:
|
||||
base_mod_env, base_mod_attrs = construct_graph(
|
||||
node, base_mod_env, base_mod_attrs
|
||||
)
|
||||
|
||||
# Add output node to `base_mod_graph` (i.e. the split graph) which will be returned.
|
||||
for node in m.graph.nodes:
|
||||
if node.op == "output":
|
||||
base_mod_graph.output(
|
||||
|
Reference in New Issue
Block a user