Add flag to fx.passes.split_module to normalize input names (#157733)

This is useful for vLLM, which runs AOTAutograd directly on graphs after
they have been split.

I created a new flag for this instead of reusing
`keep_original_node_name` (please let me know if you think I should reuse this).
The reasoning is:
- The names of the placeholder nodes is different from the targets of
  the placehoder nodes. The targets are the actual input names.
- Backwards compatibility: this API has been out for ~4 years, it
  looks public, and it has extensive public use. For example, this change
  would actually be BC-breaking to vLLM (they rely on the subgraph input
  names being different at the moment).

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157733
Approved by: https://github.com/ezyang
This commit is contained in:
rzou
2025-07-07 14:43:48 -07:00
committed by PyTorch MergeBot
parent 7381c77724
commit b9afdd9bcc
3 changed files with 65 additions and 12 deletions

View File

@ -64,7 +64,7 @@ torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.no
torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
torch.fx.passes.reinplace.reinplace(gm, *sample_args) torch.fx.passes.reinplace.reinplace(gm, *sample_args)
torch.fx.passes.runtime_assert.insert_deferred_runtime_asserts(gm: torch.fx.graph_module.GraphModule, shape_env: Any, name: str, export: bool = False) -> None torch.fx.passes.runtime_assert.insert_deferred_runtime_asserts(gm: torch.fx.graph_module.GraphModule, shape_env: Any, name: str, export: bool = False) -> None
torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False) torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, keep_original_input_name: bool = True)
torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str) torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None) torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None)
torch.fx.proxy.Proxy.keys(self) torch.fx.proxy.Proxy.keys(self)

View File

@ -791,6 +791,46 @@ terrible spacing
self.assertEqual(orig_out, submodules_out) self.assertEqual(orig_out, submodules_out)
def test_split_module_input_names(self):
class Mod(torch.nn.Module):
def forward(self, x, a0, a1, b0, b1, c0, c1):
x = x + (a0 ** 2) + (a1 / 2)
x = x + (b0 ** 2) + (b1 / 2)
x = x + (c0 ** 2) + (c1 / 2)
return x
mod = Mod()
traced = torch.fx.symbolic_trace(mod)
seen = 0
def split(n):
nonlocal seen
result = seen // 4
seen += 1
return result
split = split_module(traced, mod, split, keep_original_input_name=False)
# All the submodules should take in the inputs in the same order.
args = [torch.tensor(2.), torch.tensor(3.), torch.tensor(4.)]
output0 = split.submod_0(*args)
output1 = split.submod_1(*args)
output2 = split.submod_2(*args)
self.assertEqual(output0, output1)
self.assertEqual(output1, output2)
# Each submodule should have normalized input names
def check_ph(gm):
nodes = list(gm.graph.nodes)
self.assertEqual(nodes[0].target, "arg_0")
self.assertEqual(nodes[1].target, "arg_1")
self.assertEqual(nodes[2].target, "arg_2")
check_ph(split.submod_0)
check_ph(split.submod_1)
check_ph(split.submod_2)
def test_split_module_dead_code(self): def test_split_module_dead_code(self):
class ModWithDeadCode(torch.nn.Module): class ModWithDeadCode(torch.nn.Module):
def forward(self, x): def forward(self, x):

View File

@ -58,6 +58,7 @@ def split_module(
qualname_map: Optional[dict[str, str]] = None, qualname_map: Optional[dict[str, str]] = None,
keep_original_order: Optional[bool] = False, keep_original_order: Optional[bool] = False,
keep_original_node_name: Optional[bool] = False, keep_original_node_name: Optional[bool] = False,
keep_original_input_name: bool = True,
): ):
""" """
Creates subgraphs out of main graph Creates subgraphs out of main graph
@ -76,7 +77,10 @@ def split_module(
names in the original module. names in the original module.
keep_original_order: Optional[bool]: keep the original order of the GraphModule keep_original_order: Optional[bool]: keep the original order of the GraphModule
or use the Topological order of the new constructed GraphModule or use the Topological order of the new constructed GraphModule
keep_original_node_name: Optional[bool]: If the partitioned graphs should
have the same node names as the original graph.
keep_original_input_name: bool: If the partitioned graphs should
have the same input names as the original graph.
Returns: Returns:
GraphModule: the module after split. GraphModule: the module after split.
@ -419,11 +423,28 @@ def split_module(
for partition_name in sorted_partitions: for partition_name in sorted_partitions:
partition = partitions[partition_name] partition = partitions[partition_name]
new_inputs: dict[str, None] = {} new_inputs: dict[str, None] = {}
counter = 0
for inp in partition.inputs: for inp in partition.inputs:
orig_node = orig_nodes[inp] orig_node = orig_nodes[inp]
# We don't pass in get_attr nodes as inputs to the partition, but # We don't pass in get_attr nodes as inputs to the partition, but
# instead set them as targets and use getattr within the module # instead set them as targets and use getattr within the module
def add_placeholder():
if keep_original_input_name:
name = inp
else:
nonlocal counter
name = f"arg_{counter}"
counter += 1
placeholder = partition.graph.placeholder(
name,
type_expr=orig_nodes[inp].type,
)
new_inputs[inp] = None
return placeholder
if orig_node.op == "get_attr": if orig_node.op == "get_attr":
assert isinstance(orig_node.target, str) assert isinstance(orig_node.target, str)
@ -432,17 +453,9 @@ def split_module(
placeholder = partition.graph.get_attr(orig_node.target) placeholder = partition.graph.get_attr(orig_node.target)
partition.targets[orig_node.target] = orig_attr partition.targets[orig_node.target] = orig_attr
else: else:
placeholder = partition.graph.placeholder( placeholder = add_placeholder()
inp,
type_expr=orig_nodes[inp].type,
)
new_inputs[inp] = None
else: else:
placeholder = partition.graph.placeholder( placeholder = add_placeholder()
inp,
type_expr=orig_nodes[inp].type,
)
new_inputs[inp] = None
placeholder.meta = orig_nodes[inp].meta.copy() placeholder.meta = orig_nodes[inp].meta.copy()
partition.environment[orig_nodes[inp]] = placeholder partition.environment[orig_nodes[inp]] = placeholder
partition.inputs = new_inputs partition.inputs = new_inputs