diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect
index cd7d6374f6da..fab0dbd06676 100644
--- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect
+++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect
@@ -64,7 +64,7 @@ torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.no
 torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
 torch.fx.passes.reinplace.reinplace(gm, *sample_args)
 torch.fx.passes.runtime_assert.insert_deferred_runtime_asserts(gm: torch.fx.graph_module.GraphModule, shape_env: Any, name: str, export: bool = False) -> None
-torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False)
+torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, keep_original_input_name: bool = True)
 torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
 torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None)
 torch.fx.proxy.Proxy.keys(self)
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 434de5243c13..91b574c9b04c 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -791,6 +791,46 @@ terrible spacing
 
         self.assertEqual(orig_out, submodules_out)
 
+    def test_split_module_input_names(self):
+        class Mod(torch.nn.Module):
+            def forward(self, x, a0, a1, b0, b1, c0, c1):
+                x = x + (a0 ** 2) + (a1 / 2)
+                x = x + (b0 ** 2) + (b1 / 2)
+                x = x + (c0 ** 2) + (c1 / 2)
+                return x
+
+        mod = Mod()
+        traced = torch.fx.symbolic_trace(mod)
+
+        seen = 0
+
+        def split(n):
+            nonlocal seen
+            result = seen // 4
+            seen += 1
+            return result
+
+        split = split_module(traced, mod, split, keep_original_input_name=False)
+
+        # All the submodules should take in the inputs in the same order.
+        args = [torch.tensor(2.), torch.tensor(3.), torch.tensor(4.)]
+        output0 = split.submod_0(*args)
+        output1 = split.submod_1(*args)
+        output2 = split.submod_2(*args)
+        self.assertEqual(output0, output1)
+        self.assertEqual(output1, output2)
+
+        # Each submodule should have normalized input names
+        def check_ph(gm):
+            nodes = list(gm.graph.nodes)
+            self.assertEqual(nodes[0].target, "arg_0")
+            self.assertEqual(nodes[1].target, "arg_1")
+            self.assertEqual(nodes[2].target, "arg_2")
+
+        check_ph(split.submod_0)
+        check_ph(split.submod_1)
+        check_ph(split.submod_2)
+
     def test_split_module_dead_code(self):
         class ModWithDeadCode(torch.nn.Module):
             def forward(self, x):
diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py
index 59c560423d40..413584070d13 100644
--- a/torch/fx/passes/split_module.py
+++ b/torch/fx/passes/split_module.py
@@ -58,6 +58,7 @@ def split_module(
     qualname_map: Optional[dict[str, str]] = None,
     keep_original_order: Optional[bool] = False,
     keep_original_node_name: Optional[bool] = False,
+    keep_original_input_name: bool = True,
 ):
     """
     Creates subgraphs out of main graph
@@ -76,7 +77,10 @@ def split_module(
             names in the original module.
         keep_original_order: Optional[bool]: keep the original order of the GraphModule
             or use the Topological order of the new constructed GraphModule
-
+        keep_original_node_name: Optional[bool]: If the partitioned graphs should
+            have the same node names as the original graph.
+        keep_original_input_name: bool: If the partitioned graphs should
+            have the same input names as the original graph.
 
     Returns:
         GraphModule: the module after split.
@@ -419,11 +423,28 @@ def split_module(
     for partition_name in sorted_partitions:
         partition = partitions[partition_name]
         new_inputs: dict[str, None] = {}
+
+        counter = 0
+
         for inp in partition.inputs:
             orig_node = orig_nodes[inp]
             # We don't pass in get_attr nodes as inputs to the partition, but
             # instead set them as targets and use getattr within the module
 
+            def add_placeholder():
+                if keep_original_input_name:
+                    name = inp
+                else:
+                    nonlocal counter
+                    name = f"arg_{counter}"
+                    counter += 1
+                placeholder = partition.graph.placeholder(
+                    name,
+                    type_expr=orig_nodes[inp].type,
+                )
+                new_inputs[inp] = None
+                return placeholder
+
             if orig_node.op == "get_attr":
                 assert isinstance(orig_node.target, str)
 
@@ -432,17 +453,9 @@ def split_module(
                     placeholder = partition.graph.get_attr(orig_node.target)
                     partition.targets[orig_node.target] = orig_attr
                 else:
-                    placeholder = partition.graph.placeholder(
-                        inp,
-                        type_expr=orig_nodes[inp].type,
-                    )
-                    new_inputs[inp] = None
+                    placeholder = add_placeholder()
             else:
-                placeholder = partition.graph.placeholder(
-                    inp,
-                    type_expr=orig_nodes[inp].type,
-                )
-                new_inputs[inp] = None
+                placeholder = add_placeholder()
             placeholder.meta = orig_nodes[inp].meta.copy()
             partition.environment[orig_nodes[inp]] = placeholder
         partition.inputs = new_inputs
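
Usage sketch (not part of the patch): a minimal, illustrative example of the new keep_original_input_name flag against the patched split_module above. The TwoStage module and split_callback below are invented for demonstration; only the keyword argument itself comes from this change.

    import torch
    import torch.fx
    from torch.fx.passes.split_module import split_module

    class TwoStage(torch.nn.Module):
        # Illustrative module: one elementwise stage followed by an activation.
        def forward(self, x, scale):
            y = x * scale
            return torch.relu(y)

    def split_callback(node):
        # Route the relu into partition 1, everything else into partition 0.
        return 1 if node.target is torch.relu else 0

    mod = TwoStage()
    traced = torch.fx.symbolic_trace(mod)
    split_gm = split_module(traced, mod, split_callback, keep_original_input_name=False)

    # With keep_original_input_name=False, each submodule's placeholders are
    # renamed arg_0, arg_1, ... instead of reusing the original input names.
    for name, submod in split_gm.named_children():
        phs = [n.target for n in submod.graph.nodes if n.op == "placeholder"]
        print(name, phs)  # e.g. submod_0 ['arg_0', 'arg_1'], submod_1 ['arg_0']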