Back out "Revert D49107540: [pytorch][PR] split by tag" (#109332)

Summary:
Original commit changeset: 6391a068640b

Original Phabricator Diff: D49107540

Test Plan: same as D49107540

Differential Revision: D49297522

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109332
Approved by: https://github.com/842974287
This commit is contained in:
Wenting Wang
2023-09-16 05:29:16 +00:00
committed by PyTorch MergeBot
parent 7bce7f50f3
commit 393fe9339a
4 changed files with 160 additions and 19 deletions

View File

@ -1,6 +1,6 @@
import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple, Union
import torch.fx
from torch.fx._compatibility import compatibility
@ -11,6 +11,7 @@ from .tools_common import NodeList
__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]
@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
for layer in name.split("."):
@ -57,11 +58,13 @@ class Component:
@compatibility(is_backward_compatible=False)
def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphModule:
def split_by_tags(
gm: torch.fx.GraphModule, tags: List[str], return_fqn_mapping: bool = False
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
"""
Splits a GraphModule using tags on its graph nodes. We honor the order of
tags. For example, we have tags = ["a", "b", "c"], the function will create
the initial submodules in the order of "a_0", "b_1", "c_2".
the initial submodules in the order of "a", "b", "c".
To set a tag:
gm.graph.nodes[idx].tag = "mytag"
@ -88,13 +91,13 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:
ro_0:
ro:
def forward(self, in1):
self = self.root
linear1 = self.linear1(in1)
return linear1
main_1:
main:
def forward(self, in2, linear1):
self = self.root
linear2 = self.linear2(in2)
@ -102,12 +105,17 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
linear3 = self.linear3(cat_1)
return linear3
main_0:
main:
def forward(self, in1, in2):
self = self.root
ro_0 = self.ro_0(in1)
main_1 = self.main_1(in2, ro_0)
return main_1
Returns:
split_gm: torch fx graph after split
orig_to_split_fqn_mapping: a map between the original fqn and the fqn
after split for call_module and get_attr.
"""
def flatten(x: torch.fx.node.Argument) -> NodeList:
@ -210,9 +218,7 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
comp.orig_inputs.append(x)
placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
placeholder.meta = copy.copy(x.meta)
comp.input_placeholders.append(
placeholder
)
comp.input_placeholders.append(placeholder)
used_in_main[x] = None
return comp.input_placeholders[comp.orig_inputs.index(x)]
@ -243,6 +249,7 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
node_to_component[n].orig_outputs.append(n)
# Now we create a graphmodule for each component.
orig_to_split_fqn_mapping: Dict[str, str] = {}
for comp in all_components:
outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
@ -252,7 +259,10 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
# ((output_0, output_1, ...)).
comp.graph.output(outs[0] if len(outs) == 1 else outs)
comp.gm = lift_subgraph_as_module(gm, comp.graph)
comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
gm, subgraph=comp.graph, comp_name=comp.name
)
orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)
# Create a call_module node in main graph.
main_node = main_g.call_module(
@ -277,4 +287,8 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
if x.op == "get_attr":
setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type]
return torch.fx.GraphModule(main_root, main_g)
result_gm = torch.fx.GraphModule(main_root, main_g)
if return_fqn_mapping:
return result_gm, orig_to_split_fqn_mapping
return result_gm