mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Back out "Revert D49107540: [pytorch][PR] split by tag" (#109332)
Summary: Original commit changeset: 6391a068640b Original Phabricator Diff: D49107540 Test Plan: same as D49107540 Differential Revision: D49297522 Pull Request resolved: https://github.com/pytorch/pytorch/pull/109332 Approved by: https://github.com/842974287
This commit is contained in:
committed by
PyTorch MergeBot
parent
7bce7f50f3
commit
393fe9339a
@ -1,6 +1,6 @@
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch.fx
|
||||
from torch.fx._compatibility import compatibility
|
||||
@ -11,6 +11,7 @@ from .tools_common import NodeList
|
||||
|
||||
__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def getattr_recursive(obj, name):
|
||||
for layer in name.split("."):
|
||||
@ -57,11 +58,13 @@ class Component:
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphModule:
|
||||
def split_by_tags(
|
||||
gm: torch.fx.GraphModule, tags: List[str], return_fqn_mapping: bool = False
|
||||
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
|
||||
"""
|
||||
Splits a GraphModule using tags on its graph nodes. We honor the order of
|
||||
tags. For example, we have tags = ["a", "b", "c"], the function will create
|
||||
the initial submodules in the order of "a_0", "b_1", "c_2".
|
||||
the initial submodules in the order of "a", "b", "c".
|
||||
|
||||
To set a tag:
|
||||
gm.graph.nodes[idx].tag = "mytag"
|
||||
@ -88,13 +91,13 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
|
||||
|
||||
Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split:
|
||||
|
||||
ro_0:
|
||||
ro:
|
||||
def forward(self, in1):
|
||||
self = self.root
|
||||
linear1 = self.linear1(in1)
|
||||
return linear1
|
||||
|
||||
main_1:
|
||||
main:
|
||||
def forward(self, in2, linear1):
|
||||
self = self.root
|
||||
linear2 = self.linear2(in2)
|
||||
@ -102,12 +105,17 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
|
||||
linear3 = self.linear3(cat_1)
|
||||
return linear3
|
||||
|
||||
main_0:
|
||||
main:
|
||||
def forward(self, in1, in2):
|
||||
self = self.root
|
||||
ro_0 = self.ro_0(in1)
|
||||
main_1 = self.main_1(in2, ro_0)
|
||||
return main_1
|
||||
|
||||
Returns:
|
||||
split_gm: torch fx graph after split
|
||||
orig_to_split_fqn_mapping: a map between the original fqn and the fqn
|
||||
after split for call_module and get_attr.
|
||||
"""
|
||||
|
||||
def flatten(x: torch.fx.node.Argument) -> NodeList:
|
||||
@ -210,9 +218,7 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
|
||||
comp.orig_inputs.append(x)
|
||||
placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
|
||||
placeholder.meta = copy.copy(x.meta)
|
||||
comp.input_placeholders.append(
|
||||
placeholder
|
||||
)
|
||||
comp.input_placeholders.append(placeholder)
|
||||
used_in_main[x] = None
|
||||
|
||||
return comp.input_placeholders[comp.orig_inputs.index(x)]
|
||||
@ -243,6 +249,7 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
|
||||
node_to_component[n].orig_outputs.append(n)
|
||||
|
||||
# Now we create a graphmodule for each component.
|
||||
orig_to_split_fqn_mapping: Dict[str, str] = {}
|
||||
for comp in all_components:
|
||||
outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
|
||||
|
||||
@ -252,7 +259,10 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
|
||||
# ((output_0, output_1, ...)).
|
||||
comp.graph.output(outs[0] if len(outs) == 1 else outs)
|
||||
|
||||
comp.gm = lift_subgraph_as_module(gm, comp.graph)
|
||||
comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
|
||||
gm, subgraph=comp.graph, comp_name=comp.name
|
||||
)
|
||||
orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)
|
||||
|
||||
# Create a call_module node in main graph.
|
||||
main_node = main_g.call_module(
|
||||
@ -277,4 +287,8 @@ def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphMo
|
||||
if x.op == "get_attr":
|
||||
setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type]
|
||||
|
||||
return torch.fx.GraphModule(main_root, main_g)
|
||||
result_gm = torch.fx.GraphModule(main_root, main_g)
|
||||
if return_fqn_mapping:
|
||||
return result_gm, orig_to_split_fqn_mapping
|
||||
|
||||
return result_gm
|
||||
|
Reference in New Issue
Block a user