[Easy] Refactor post grad application of passes (#139293)

Refactors GraphTransformObserver to hook into the BisectionManager's pass application, and reworks the post-grad passes to use it.
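
Roughly, the change replaces the standalone apply_pass helper with GraphTransformObserver.apply_graph_pass, which now performs the bisection check itself. Below is a minimal sketch of the new pattern (not part of the commit), assuming the post-PR observer API; `my_graph_pass` and `run_post_grad_sketch` are made-up names:

    import functools
    import torch
    from torch.fx.passes.graph_transform_observer import GraphTransformObserver

    def my_graph_pass(graph: torch.fx.Graph) -> None:
        # Hypothetical stand-in for a real post-grad pass.
        graph.eliminate_dead_code()

    def run_post_grad_sketch(gm: torch.fx.GraphModule) -> None:
        # Bind the subsystem once, mirroring the functools.partial used in post_grad.py.
        Observer = functools.partial(GraphTransformObserver, subsystem="post_grad_passes")

        # Previously: apply_pass(lambda: my_graph_pass(gm.graph), "my_graph_pass").
        # Now the observer hands gm.graph to the pass, and skips it entirely when
        # BisectionManager.disable_subsystem("inductor", "post_grad_passes", ...) says to.
        Observer(gm, "my_graph_pass").apply_graph_pass(my_graph_pass)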

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139293
Approved by: https://github.com/exclamaforte
ghstack dependencies: #139292
Authored by eellison on 2024-10-30 15:19:24 -07:00, committed by PyTorch MergeBot
parent 5075046db2
commit f93ebb2cf4
3 changed files with 82 additions and 49 deletions

View File

@@ -1,10 +1,11 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
+import functools
 import itertools
 import logging
 import operator
 from collections import Counter, defaultdict
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set
 
 import torch
 import torch._inductor as inductor
@@ -17,7 +18,6 @@ from torch._inductor.virtualized import ops
 from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
 from torch._utils_internal import upload_graph
 from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
-from torch.fx.passes.graph_transform_observer import GraphTransformObserver
 
 from .. import config, ir, pattern_matcher
 from ..codegen.common import BackendFeature, has_backend_feature
@@ -65,19 +65,6 @@ pass_patterns = [
 ]
 
 
-def apply_pass(pass_fn: Callable[[], object], name: Optional[str] = None) -> None:
-    # TODO - we should just make this part of GraphTransformObserver
-    from torch._inductor.bisect_helper import BisectionManager
-
-    debug_info: Optional[Callable[[], str]] = None
-    if name is not None:
-        debug_info = lambda: name  # noqa: E731
-    if BisectionManager.disable_subsystem("inductor", "post_grad_passes", debug_info):
-        return
-
-    pass_fn()
-
-
 def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     """
     Passes that run on after grad. This is called once on the forwards
@@ -85,6 +72,11 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     The IR here has been normalized and functionalized.
     """
+    GraphTransformObserver = functools.partial(
+        torch.fx.passes.graph_transform_observer.GraphTransformObserver,
+        subsystem="post_grad_passes",
+    )
+
     if not torch._dynamo.config.skip_fsdp_hooks:
         remove_fsdp2_unsharded_param_graph_input_usage(gm.graph)
@@ -93,26 +85,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         gm.graph.eliminate_dead_code()
 
     if is_inference and config.reorder_for_locality:
-        apply_pass(lambda: reorder_for_locality(gm.graph), "reorder_for_locality")
+        GraphTransformObserver(gm, "reorder_for_locality").apply_graph_pass(
+            reorder_for_locality
+        )
 
     fake_tensor_updater = FakeTensorUpdater(gm.graph)
 
     if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass:
-        with GraphTransformObserver(gm, "post_grad_custom_pre_pass"):
-            apply_pass(
-                lambda: post_grad_custom_pre_pass(gm.graph), "post_grad_custom_pre_pass"
-            )
+        GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass(
+            post_grad_custom_pre_pass
+        )
 
     if config.pattern_matcher:
         lazy_init()
         optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
-        apply_pass(
-            lambda: group_batch_fusion_passes(gm.graph, pre_grad=False),
-            "group_batch_fusion_passes",
+        GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass(
+            functools.partial(group_batch_fusion_passes, pre_grad=False)
         )
-        apply_pass(lambda: remove_noop_ops(gm.graph), "remove_noop_ops")
+        GraphTransformObserver(gm, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
        for i, patterns in enumerate(pass_patterns):
-            apply_pass(lambda: patterns.apply(gm.graph), f"pass_pattern_{i}")  # type: ignore[arg-type]
+            GraphTransformObserver(gm, f"pass_pattern_{i}").apply_graph_pass(
+                patterns.apply
+            )
        for pass_name in config.post_grad_fusion_options:
             # skip all patterns for group batch fusions
             if pass_name in POST_GRAD_FUSIONS:
@@ -121,7 +115,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
             inductor_before_change = save_inductor_dict(
                 [pattern_matcher_pass.pass_name]
             )
-            apply_pass(lambda: pattern_matcher_pass.apply(gm.graph), pass_name)  # type: ignore[arg-type]
+            GraphTransformObserver(gm, pass_name).apply_graph_pass(
+                pattern_matcher_pass.apply
+            )
             if not is_same_dict(counters["inductor"], inductor_before_change):
                 optimus_scuba_log[
                     f"{pattern_matcher_pass.pass_name}_post_grad"
@@ -133,37 +129,37 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         micro_pipeline_tp_pass(gm.graph)
 
     if config._fuse_ddp_communication:
-        apply_pass(
-            lambda: fuse_ddp_communication(
-                gm.graph,
+        GraphTransformObserver(gm, "fuse_ddp_communication").apply_graph_pass(
+            lambda graph: fuse_ddp_communication(
+                graph,
                 config._fuse_ddp_communication_passes,
                 config._fuse_ddp_bucket_size,
             ),
-            "fuse_ddp_communication",
         )
 
     if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
-        with GraphTransformObserver(gm, "post_grad_custom_post_pass"):
-            apply_pass(
-                lambda: post_grad_custom_post_pass(gm.graph),
-                "post_grad_custom_post_pass",
-            )
+        GraphTransformObserver(gm, "post_grad_custom_post_pass").apply_graph_pass(
+            post_grad_custom_post_pass
+        )
 
-    apply_pass(lambda: stable_topological_sort(gm.graph), "stable_sort")
+    GraphTransformObserver(gm, "stable_sort").apply_graph_pass(stable_topological_sort)
 
-    apply_pass(lambda: move_constructors_to_gpu(gm.graph), "move_constructors_to_cuda")
+    GraphTransformObserver(gm, "move_constructors_to_cuda").apply_graph_pass(
+        move_constructors_to_gpu
+    )
 
     fake_tensor_updater.incremental_update()
 
     # Keep these last, since they introduces mutation. Look at
     # ./fx_passes/README.md for a discussion of mutation invariants.
-    apply_pass(lambda: reinplace_inplaceable_ops(gm.graph), "reinplace_inplaceable_ops")
-    apply_pass(
-        lambda: decompose_auto_functionalized(gm.graph), "decompose_auto_functionalized"
+    GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(
+        reinplace_inplaceable_ops
    )
-    apply_pass(
-        lambda: comms.reinplace_fsdp_all_gather(gm.graph), "reinplace_fsdp_all_gather"
+    GraphTransformObserver(gm, "decompose_auto_functionalized").apply_graph_pass(
+        decompose_auto_functionalized
    )
+    GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass(
+        comms.reinplace_fsdp_all_gather
+    )
     gm.recompile()

View File

@@ -1717,7 +1717,7 @@ class PatternMatcherPass:
     def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
         return self.patterns[item]
 
-    def apply(self, gm: torch.fx.GraphModule) -> int:
+    def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
         if not self.patterns:
             return 0
         if isinstance(gm, torch.fx.GraphModule):
@@ -1745,6 +1745,7 @@
         if has_call_module:
             nodes.append(graph.find_nodes(op="call_module", sort=False))
         pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
+        assert isinstance(gm, torch.fx.GraphModule)
         with GraphTransformObserver(gm, pass_name):
             for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
                 target = extract_target(node)
View File

@@ -1,10 +1,15 @@
 # mypy: allow-untyped-defs
 import os
-from typing import Optional
+from typing import Callable, Optional, TypeVar
 
+from torch.fx import Graph
 from torch.fx._compatibility import compatibility
 from torch.fx.graph_module import GraphModule
 
+T = TypeVar("T")
+
 from .graph_drawer import FxGraphDrawer
@@ -16,12 +21,20 @@ class GraphTransformObserver:
     __pass_count = 0
 
     def __init__(
-        self, gm: GraphModule, passname: str, *, log_url: Optional[str] = None
+        self,
+        gm: GraphModule,
+        passname: str,
+        subsystem: Optional[str] = None,
+        log_url: Optional[str] = None,
     ):
         """
         log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified
         """
+        self.gm = gm
+        self.passname = passname
+        self.subsystem = subsystem
+
         # If log_url is None, we don't log anything
         if log_url is None:
             from torch._inductor.config import trace
@@ -32,8 +45,6 @@ class GraphTransformObserver:
         if self.log_url is None:
             return
         GraphTransformObserver.__pass_count += 1
-        self.gm = gm
-        self.passname = passname
 
         self.input_dot_graph = FxGraphDrawer(
             self.gm,
@@ -46,6 +57,31 @@ class GraphTransformObserver:
     def get_current_pass_count(cls):
         return cls.__pass_count
 
+    def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> Optional[T]:
+        with self:
+            if not self._check_disable_pass():
+                return pass_fn(self.gm)
+
+        return None
+
+    def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> Optional[T]:
+        with self:
+            if not self._check_disable_pass():
+                return pass_fn(self.gm.graph)
+
+        return None
+
+    def _check_disable_pass(self):
+        if self.subsystem is None:
+            return False
+
+        debug_info = lambda: self.passname  # noqa: E731
+        from torch._inductor.bisect_helper import BisectionManager
+
+        return BisectionManager.disable_subsystem(
+            "inductor", self.subsystem, debug_info
+        )
+
     def __enter__(self):
         if self.log_url is None or self.gm is None:
             return self
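
For reference, a minimal usage sketch of the two entry points added above, again assuming the post-PR observer API; the toy `AddRelu` module and the pass names are hypothetical:

    import torch
    from torch.fx import symbolic_trace
    from torch.fx.passes.graph_transform_observer import GraphTransformObserver

    class AddRelu(torch.nn.Module):
        # Toy module, used only to produce an FX GraphModule to observe.
        def forward(self, x):
            return torch.relu(x + 1)

    gm = symbolic_trace(AddRelu())

    # apply_graph_pass hands gm.graph to the callable; its result is returned
    # unless the bisection subsystem disables the pass, in which case it is None.
    node_count = GraphTransformObserver(
        gm, "count_nodes", subsystem="post_grad_passes"
    ).apply_graph_pass(lambda graph: len(graph.nodes))

    # apply_gm_pass hands the GraphModule itself to the callable.
    GraphTransformObserver(gm, "recompile", subsystem="post_grad_passes").apply_gm_pass(
        lambda mod: mod.recompile()
    )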