mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[fx] Allow customization of submod name in split graph (#164035)
Fixes #164030: HOP and pipelining both name things submod_i by adding an optional argument `partition_affix` to `split_module` API. Pull Request resolved: https://github.com/pytorch/pytorch/pull/164035 Approved by: https://github.com/ezyang ghstack dependencies: #164045
This commit is contained in:
@ -63,7 +63,7 @@ torch.fx.node.map_aggregate(a: torch.fx.node.Argument, fn: Callable[[torch.fx.no
|
||||
torch.fx.node.map_arg(a: torch.fx.node.Argument, fn: Callable[[torch.fx.node.Node], torch.fx.node.Argument]) -> torch.fx.node.Argument
|
||||
torch.fx.passes.reinplace.reinplace(gm, *sample_args)
|
||||
torch.fx.passes.runtime_assert.insert_deferred_runtime_asserts(gm: torch.fx.graph_module.GraphModule, shape_env: Any, name: str, export: bool = False) -> None
|
||||
torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, keep_original_input_name: bool = True)
|
||||
torch.fx.passes.split_module.split_module(m: torch.fx.graph_module.GraphModule, root_m: torch.nn.modules.module.Module, split_callback: Callable[[torch.fx.node.Node], int], qualname_map: Optional[Dict[str, str]] = None, keep_original_order: Optional[bool] = False, keep_original_node_name: Optional[bool] = False, keep_original_input_name: bool = True, partition_affix: Optional[str] = None)
|
||||
torch.fx.proxy.Attribute.__init__(self, root: torch.fx.proxy.Proxy, attr: str)
|
||||
torch.fx.proxy.Proxy.__init__(self, node: torch.fx.node.Node, tracer: 'Optional[TracerBase]' = None)
|
||||
torch.fx.proxy.Proxy.keys(self)
|
||||
|
@ -59,6 +59,8 @@ def split_module(
|
||||
keep_original_order: Optional[bool] = False,
|
||||
keep_original_node_name: Optional[bool] = False,
|
||||
keep_original_input_name: bool = True,
|
||||
*,
|
||||
partition_affix: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Creates subgraphs out of main graph
|
||||
@ -81,6 +83,8 @@ def split_module(
|
||||
have the same node names as the original graph.
|
||||
keep_original_input_name: bool: If the partitioned graphs should
|
||||
have the same input names as the original graph.
|
||||
partition_affix: Optional[str]: If specified, the submodules' names will contain
|
||||
the affix, e.g. "submod_<affix>_<idx>".
|
||||
|
||||
Returns:
|
||||
GraphModule: the module after split.
|
||||
@ -253,7 +257,13 @@ def split_module(
|
||||
use_partition.dependencies.setdefault(defined)
|
||||
|
||||
def instantiate_node_partition_mapping(node):
|
||||
partition_name = str(split_callback(node))
|
||||
partition_idx = split_callback(node)
|
||||
partition_name = str(partition_idx)
|
||||
if partition_affix is not None:
|
||||
# For example, if user specifies partition_affix = "pp", then the
|
||||
# partition name will be "pp_0", "pp_1", etc
|
||||
partition_name = "_".join([partition_affix, partition_name])
|
||||
|
||||
log.debug(
|
||||
"instantiate_node_partition_mapping %s (%s)", node.name, partition_name
|
||||
)
|
||||
|
Reference in New Issue
Block a user