mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Easy] Refactor post grad application of passes (#139293)
Refactors GraphTransformObserver to hook into the bisect manager pass application. And reworks post grad passes to use it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139293 Approved by: https://github.com/exclamaforte ghstack dependencies: #139292
This commit is contained in:
committed by
PyTorch MergeBot
parent
5075046db2
commit
f93ebb2cf4
@ -1717,7 +1717,7 @@ class PatternMatcherPass:
|
||||
def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
|
||||
return self.patterns[item]
|
||||
|
||||
def apply(self, gm: torch.fx.GraphModule) -> int:
|
||||
def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
|
||||
if not self.patterns:
|
||||
return 0
|
||||
if isinstance(gm, torch.fx.GraphModule):
|
||||
@ -1745,6 +1745,7 @@ class PatternMatcherPass:
|
||||
if has_call_module:
|
||||
nodes.append(graph.find_nodes(op="call_module", sort=False))
|
||||
pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
|
||||
assert isinstance(gm, torch.fx.GraphModule)
|
||||
with GraphTransformObserver(gm, pass_name):
|
||||
for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
|
||||
target = extract_target(node)
|
||||
|
Reference in New Issue
Block a user