[Easy] Refactor post grad application of passes (#139293)

Refactors GraphTransformObserver to hook into the bisect manager pass application. And reworks post grad passes to use it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139293
Approved by: https://github.com/exclamaforte
ghstack dependencies: #139292
This commit is contained in:
eellison
2024-10-30 15:19:24 -07:00
committed by PyTorch MergeBot
parent 5075046db2
commit f93ebb2cf4
3 changed files with 82 additions and 49 deletions

View File

@ -1717,7 +1717,7 @@ class PatternMatcherPass:
def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
return self.patterns[item]
def apply(self, gm: torch.fx.GraphModule) -> int:
def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
if not self.patterns:
return 0
if isinstance(gm, torch.fx.GraphModule):
@ -1745,6 +1745,7 @@ class PatternMatcherPass:
if has_call_module:
nodes.append(graph.find_nodes(op="call_module", sort=False))
pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
assert isinstance(gm, torch.fx.GraphModule)
with GraphTransformObserver(gm, pass_name):
for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
target = extract_target(node)