[Easy] Refactor post grad application of passes (#139293)
Refactors GraphTransformObserver to hook into the bisect manager's pass application, and reworks the post-grad passes to use it.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139293
Approved by: https://github.com/exclamaforte
ghstack dependencies: #139292
Committed by: PyTorch MergeBot
Parent: 5075046db2
Commit: f93ebb2cf4
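In rough terms, the refactor replaces the local `apply_pass` wrapper with `GraphTransformObserver.apply_graph_pass`, which now performs the BisectionManager check itself. A minimal sketch of the new calling pattern (illustration only, not code from this commit; the function name and the dead-code pass are made up):

```python
import functools

import torch
from torch.fx.passes.graph_transform_observer import GraphTransformObserver


def run_post_grad_style_passes(gm: torch.fx.GraphModule) -> None:
    # Bind the subsystem once, as post_grad_passes() now does; every observer
    # built from this partial consults BisectionManager's
    # ("inductor", "post_grad_passes") disable check before running its pass.
    observer = functools.partial(GraphTransformObserver, subsystem="post_grad_passes")

    # Old style (removed by this commit):
    #     apply_pass(lambda: remove_noop_ops(gm.graph), "remove_noop_ops")
    # New style: hand the pass callable directly to the observer.
    observer(gm, "eliminate_dead_code").apply_graph_pass(
        lambda graph: graph.eliminate_dead_code()
    )
    gm.recompile()
```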
@@ -1,10 +1,11 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
+import functools
 import itertools
 import logging
 import operator
 from collections import Counter, defaultdict
-from typing import Any, Callable, Dict, List, Optional, Set
+from typing import Any, Dict, List, Optional, Set

 import torch
 import torch._inductor as inductor
@@ -17,7 +18,6 @@ from torch._inductor.virtualized import ops
 from torch._prims_common import is_boolean_dtype, is_expandable_to, is_integer_dtype
 from torch._utils_internal import upload_graph
 from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq
-from torch.fx.passes.graph_transform_observer import GraphTransformObserver

 from .. import config, ir, pattern_matcher
 from ..codegen.common import BackendFeature, has_backend_feature
@@ -65,19 +65,6 @@ pass_patterns = [
 ]


-def apply_pass(pass_fn: Callable[[], object], name: Optional[str] = None) -> None:
-    # TODO - we should just make this part of GraphTransformObserver
-    from torch._inductor.bisect_helper import BisectionManager
-
-    debug_info: Optional[Callable[[], str]] = None
-    if name is not None:
-        debug_info = lambda: name  # noqa: E731
-
-    if BisectionManager.disable_subsystem("inductor", "post_grad_passes", debug_info):
-        return
-    pass_fn()
-
-
 def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     """
     Passes that run on after grad. This is called once on the forwards
@@ -85,6 +72,11 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):

     The IR here has been normalized and functionalized.
     """
+    GraphTransformObserver = functools.partial(
+        torch.fx.passes.graph_transform_observer.GraphTransformObserver,
+        subsystem="post_grad_passes",
+    )
+
     if not torch._dynamo.config.skip_fsdp_hooks:
         remove_fsdp2_unsharded_param_graph_input_usage(gm.graph)

@@ -93,26 +85,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
     gm.graph.eliminate_dead_code()

     if is_inference and config.reorder_for_locality:
-        apply_pass(lambda: reorder_for_locality(gm.graph), "reorder_for_locality")
+        GraphTransformObserver(gm, "reorder_for_locality").apply_graph_pass(
+            reorder_for_locality
+        )

     fake_tensor_updater = FakeTensorUpdater(gm.graph)

     if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass:
-        with GraphTransformObserver(gm, "post_grad_custom_pre_pass"):
-            apply_pass(
-                lambda: post_grad_custom_pre_pass(gm.graph), "post_grad_custom_pre_pass"
-            )
+        GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass(
+            post_grad_custom_pre_pass
+        )

     if config.pattern_matcher:
         lazy_init()
         optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph)
-        apply_pass(
-            lambda: group_batch_fusion_passes(gm.graph, pre_grad=False),
-            "group_batch_fusion_passes",
+        GraphTransformObserver(gm, "post_grad_custom_pre_pass").apply_graph_pass(
+            functools.partial(group_batch_fusion_passes, pre_grad=False)
         )
-        apply_pass(lambda: remove_noop_ops(gm.graph), "remove_noop_ops")
+        GraphTransformObserver(gm, "remove_noop_ops").apply_graph_pass(remove_noop_ops)
         for i, patterns in enumerate(pass_patterns):
-            apply_pass(lambda: patterns.apply(gm.graph), f"pass_pattern_{i}")  # type: ignore[arg-type]
+            GraphTransformObserver(gm, f"pass_pattern_{i}").apply_graph_pass(
+                patterns.apply
+            )
         for pass_name in config.post_grad_fusion_options:
             # skip all patterns for group batch fusions
             if pass_name in POST_GRAD_FUSIONS:
@@ -121,7 +115,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
             inductor_before_change = save_inductor_dict(
                 [pattern_matcher_pass.pass_name]
             )
-            apply_pass(lambda: pattern_matcher_pass.apply(gm.graph), pass_name)  # type: ignore[arg-type]
+            GraphTransformObserver(gm, pass_name).apply_graph_pass(
+                pattern_matcher_pass.apply
+            )
             if not is_same_dict(counters["inductor"], inductor_before_change):
                 optimus_scuba_log[
                     f"{pattern_matcher_pass.pass_name}_post_grad"
@@ -133,37 +129,37 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
         micro_pipeline_tp_pass(gm.graph)

     if config._fuse_ddp_communication:
-        apply_pass(
-            lambda: fuse_ddp_communication(
-                gm.graph,
+        GraphTransformObserver(gm, "fuse_ddp_communication").apply_graph_pass(
+            lambda graph: fuse_ddp_communication(
+                graph,
                 config._fuse_ddp_communication_passes,
                 config._fuse_ddp_bucket_size,
             ),
-            "fuse_ddp_communication",
-        )
+        )

     if post_grad_custom_post_pass := config.post_grad_custom_post_pass:
-        with GraphTransformObserver(gm, "post_grad_custom_post_pass"):
-            apply_pass(
-                lambda: post_grad_custom_post_pass(gm.graph),
-                "post_grad_custom_post_pass",
-            )
+        GraphTransformObserver(gm, "post_grad_custom_post_pass").apply_graph_pass(
+            post_grad_custom_post_pass
+        )

-    apply_pass(lambda: stable_topological_sort(gm.graph), "stable_sort")
+    GraphTransformObserver(gm, "stable_sort").apply_graph_pass(stable_topological_sort)

-    apply_pass(lambda: move_constructors_to_gpu(gm.graph), "move_constructors_to_cuda")
+    GraphTransformObserver(gm, "move_constructors_to_cuda").apply_graph_pass(
+        move_constructors_to_gpu
+    )

     fake_tensor_updater.incremental_update()

     # Keep these last, since they introduces mutation. Look at
     # ./fx_passes/README.md for a discussion of mutation invariants.
-    apply_pass(lambda: reinplace_inplaceable_ops(gm.graph), "reinplace_inplaceable_ops")
-    apply_pass(
-        lambda: decompose_auto_functionalized(gm.graph), "decompose_auto_functionalized"
-    )
-
-    apply_pass(
-        lambda: comms.reinplace_fsdp_all_gather(gm.graph), "reinplace_fsdp_all_gather"
-    )
+    GraphTransformObserver(gm, "reinplace_inplaceable_ops").apply_graph_pass(
+        reinplace_inplaceable_ops
+    )
+    GraphTransformObserver(gm, "decompose_auto_functionalized").apply_graph_pass(
+        decompose_auto_functionalized
+    )
+    GraphTransformObserver(gm, "reinplace_fsdp_all_gather").apply_graph_pass(
+        comms.reinplace_fsdp_all_gather
+    )

     gm.recompile()
@@ -1717,7 +1717,7 @@ class PatternMatcherPass:
     def __getitem__(self, item: Tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
        return self.patterns[item]

-    def apply(self, gm: torch.fx.GraphModule) -> int:
+    def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
         if not self.patterns:
             return 0
         if isinstance(gm, torch.fx.GraphModule):
@@ -1745,6 +1745,7 @@ class PatternMatcherPass:
         if has_call_module:
             nodes.append(graph.find_nodes(op="call_module", sort=False))
         pass_name = self.pass_name if self.pass_name is not None else "pattern_matcher"
+        assert isinstance(gm, torch.fx.GraphModule)
         with GraphTransformObserver(gm, pass_name):
             for node in sorted(itertools.chain.from_iterable(nodes), reverse=True):
                 target = extract_target(node)
@@ -1,10 +1,15 @@
 # mypy: allow-untyped-defs
 import os
-from typing import Optional
+from typing import Callable, Optional, TypeVar

+from torch.fx import Graph
 from torch.fx._compatibility import compatibility
 from torch.fx.graph_module import GraphModule

+
+T = TypeVar("T")
+
+
 from .graph_drawer import FxGraphDrawer
@@ -16,12 +21,20 @@ class GraphTransformObserver:
     __pass_count = 0

     def __init__(
-        self, gm: GraphModule, passname: str, *, log_url: Optional[str] = None
+        self,
+        gm: GraphModule,
+        passname: str,
+        subsystem: Optional[str] = None,
+        log_url: Optional[str] = None,
     ):
         """
         log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified
         """

+        self.gm = gm
+        self.passname = passname
+        self.subsystem = subsystem
+
         # If log_url is None, we don't log anything
         if log_url is None:
             from torch._inductor.config import trace
@@ -32,8 +45,6 @@ class GraphTransformObserver:
         if self.log_url is None:
             return
         GraphTransformObserver.__pass_count += 1
-        self.gm = gm
-        self.passname = passname

         self.input_dot_graph = FxGraphDrawer(
             self.gm,
@@ -46,6 +57,31 @@ class GraphTransformObserver:
     def get_current_pass_count(cls):
         return cls.__pass_count

+    def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> Optional[T]:
+        with self:
+            if not self._check_disable_pass():
+                return pass_fn(self.gm)
+
+        return None
+
+    def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> Optional[T]:
+        with self:
+            if not self._check_disable_pass():
+                return pass_fn(self.gm.graph)
+
+        return None
+
+    def _check_disable_pass(self):
+        if self.subsystem is None:
+            return False
+
+        debug_info = lambda: self.passname  # noqa: E731
+        from torch._inductor.bisect_helper import BisectionManager
+
+        return BisectionManager.disable_subsystem(
+            "inductor", self.subsystem, debug_info
+        )
+
     def __enter__(self):
         if self.log_url is None or self.gm is None:
             return self
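For reference, a self-contained sketch (not part of the commit; the `Toy` module and the "dce" pass name are invented) exercising the new `apply_graph_pass` hook and its `subsystem` gating:

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.graph_transform_observer import GraphTransformObserver


class Toy(torch.nn.Module):
    def forward(self, x):
        unused = x.relu()  # dead node; the pass below erases it
        return x * 2


gm = symbolic_trace(Toy())

# With subsystem=None the bisect check is a no-op and the pass always runs;
# with subsystem="post_grad_passes" it would be skipped whenever
# BisectionManager disables that subsystem.
changed = GraphTransformObserver(gm, "dce", subsystem=None).apply_graph_pass(
    lambda graph: graph.eliminate_dead_code()
)
gm.recompile()
print(changed)  # True if dead code was removed, None if the pass was skipped
```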