[fx] Allow customization of submod name in split graph (#164035)

Fixes #164030: HOP and pipelining both name things submod_i
by adding an optional argument `partition_affix` to `split_module` API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164035
Approved by: https://github.com/ezyang
ghstack dependencies: #164045
This commit is contained in:
Ke Wen
2025-09-28 13:29:46 -07:00
committed by PyTorch MergeBot
parent 4fd70d4e7b
commit 615da7b95e
2 changed files with 12 additions and 2 deletions

View File

@ -63,7 +63,7 @@ torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.no
torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
torch.fx.passes.reinplace.reinplace(gm, *sample_args)
torch.fx.passes.runtime_assert.insert_deferred_runtime_asserts(gm: torch.fx.graph_module.GraphModule, shape_env: Any, name: str, export: bool = False) -> None
torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, keep_original_input_name: bool = True)
torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, keep_original_input_name: bool = True, partition_affix: Optional[str] = None)
torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None)
torch.fx.proxy.Proxy.keys(self)

View File

@ -59,6 +59,8 @@ def split_module(
keep_original_order: Optional[bool] = False,
keep_original_node_name: Optional[bool] = False,
keep_original_input_name: bool = True,
*,
partition_affix: Optional[str] = None,
):
"""
Creates subgraphs out of main graph
@ -81,6 +83,8 @@ def split_module(
have the same node names as the original graph.
keep_original_input_name: bool: If the partitioned graphs should
have the same input names as the original graph.
partition_affix: Optional[str]: If specified, the submodules' names will contain
the affix, e.g. "submod_<affix>_<idx>".
Returns:
GraphModule: the module after split.
@ -253,7 +257,13 @@ def split_module(
use_partition.dependencies.setdefault(defined)
def instantiate_node_partition_mapping(node):
partition_name = str(split_callback(node))
partition_idx = split_callback(node)
partition_name = str(partition_idx)
if partition_affix is not None:
# For example, if user specifies partition_affix = "pp", then the
# partition name will be "pp_0", "pp_1", etc
partition_name = "_".join([partition_affix, partition_name])
log.debug(
"instantiate_node_partition_mapping %s (%s)", node.name, partition_name
)