Files
pytorch/torch/fx/passes/graph_transform_observer.py
Yuanyuan Chen b11593c31b [8/N] Apply ruff UP035 rule (#165214)
This is follow-up of #164653 to continue applying `UP035` fixes. The purpose is to finally enable this rule.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165214
Approved by: https://github.com/ezyang
2025-10-15 03:18:57 +00:00

230 lines
7.6 KiB
Python

# mypy: allow-untyped-defs
import os
from collections.abc import Callable
from typing import Optional, TypeVar
from torch.fx import Graph, Node
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.traceback import NodeSource, NodeSourceAction
# Generic result type of the pass functions run by
# GraphTransformObserver.apply_gm_pass / apply_graph_pass.
T = TypeVar("T")
from .graph_drawer import FxGraphDrawer
# GraphTransformObserver is the only name exported as public API.
__all__ = ["GraphTransformObserver"]
@compatibility(is_backward_compatible=False)
class GraphTransformObserver:
    """Context manager that observes the FX graph mutations made by one pass.

    The observer is "active" when a log URL is configured or when
    ``inductor_config.trace.provenance_tracking_level == 1``. While active and
    entered, it registers node create/erase/replace and deepcopy hooks on
    ``gm``. Graph modules deep-copied from ``gm`` inside the context are
    tracked so their hooks can be removed on exit. The hooks:

    * record the names of nodes created and erased during the pass;
    * append ``NodeSource`` provenance entries to ``node.meta["from_node"]``
      for created and replaced nodes.

    When a log URL is configured, exiting the context also writes two dot
    files into that directory. The input graph is written with erased nodes
    highlighted, and the output graph with created nodes highlighted. This
    happens only if the pass created or erased at least one node.
    """

    # Number of logging-enabled observers constructed so far; used to number
    # the dot files each pass writes.
    __pass_count = 0

    def __init__(
        self,
        gm: GraphModule,
        passname: str,
        subsystem: Optional[str] = None,
        log_url: Optional[str] = None,
    ):
        """
        Args:
            gm: graph module the observed pass transforms.
            passname: name of the pass; used in dot file names and recorded in
                ``NodeSource`` provenance entries.
            subsystem: optional subsystem name handed to
                ``CompilerBisector.disable_subsystem`` to decide whether the
                pass should be skipped.
            log_url: directory to write dot files into. log_url is inferred to
                be torch._inductor.config.trace.log_url_for_graph_xform unless
                otherwise specified; if it is still None nothing is logged.
        """
        from torch._inductor import config as inductor_config

        self.gm = gm
        self.passname = passname
        self.subsystem = subsystem

        if log_url is None:
            log_url = inductor_config.trace.log_url_for_graph_xform
        self.log_url = log_url

        self.active = (
            self.log_url is not None
            or inductor_config.trace.provenance_tracking_level == 1
        )

        if self.active:
            self.erased_nodes: set[str] = set()
            self.created_nodes: set[str] = set()
            self.name_to_node: dict[str, Node] = {}
            # record graph modules deepcopied from self.gm, so we can remove
            # hooks on them when exiting the context
            self.copied_gms: list[GraphModule] = []

            self._node_creation_hook = self.get_node_creation_hook()
            self._node_erase_hook = self.get_node_erase_hook()
            self._node_replace_hook = self.get_node_replace_hook()
            self._deepcopy_hook = self.get_deepcopy_hook()

        # If log_url is None, we don't log anything
        if self.log_url is None:
            return

        GraphTransformObserver.__pass_count += 1
        # Snapshot this pass's number now. Other observers constructed while
        # this one is alive bump the shared counter, and the input/output dot
        # files written in __exit__ must carry the number under which the
        # input graph below was captured.
        self._pass_index = GraphTransformObserver.__pass_count

        # Render the graph as it is before the pass runs.
        self.input_dot_graph = self._draw_dot_graph()

    @classmethod
    def get_current_pass_count(cls) -> int:
        """Return how many logging-enabled observers have been constructed."""
        return cls.__pass_count

    def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> Optional[T]:
        """Run ``pass_fn(self.gm)`` inside this observer's context.

        Returns the pass result, or None if the pass is disabled by
        ``_check_disable_pass``.
        """
        with self:
            if not self._check_disable_pass():
                return pass_fn(self.gm)
        return None

    def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> Optional[T]:
        """Run ``pass_fn(self.gm.graph)`` inside this observer's context.

        Returns the pass result, or None if the pass is disabled by
        ``_check_disable_pass``.
        """
        with self:
            if not self._check_disable_pass():
                return pass_fn(self.gm.graph)
        return None

    def _check_disable_pass(self):
        """Return whether ``CompilerBisector`` disables this pass's subsystem.

        Always False when no subsystem was given.
        """
        if self.subsystem is None:
            return False

        from torch._inductor.compiler_bisector import CompilerBisector

        def debug_info():
            return self.passname

        return CompilerBisector.disable_subsystem(
            "inductor", self.subsystem, debug_info
        )

    def __enter__(self):
        """Install the hooks on ``self.gm`` and reset the recorded state."""
        if not self.active:
            return self

        self.gm._register_create_node_hook(self._node_creation_hook)
        self.gm._register_erase_node_hook(self._node_erase_hook)
        self.gm._register_replace_node_hook(self._node_replace_hook)
        self.gm._register_deepcopy_hook(self._deepcopy_hook)

        self.erased_nodes.clear()
        self.created_nodes.clear()
        self.name_to_node.clear()
        self.copied_gms.clear()

        # Seed the name -> node table (read by the replace hook) with the
        # nodes that exist before the pass runs.
        for node in self.gm.graph.nodes:
            self.name_to_node[node.name] = node

        return self

    def __exit__(self, type, value, tb):
        """Remove the hooks and, when logging, write the before/after dot files.

        Never suppresses an exception raised inside the context.
        """
        if not self.active:
            return

        # Unregister from gm and from every deep copy recorded by the
        # deepcopy hook while the context was active.
        for gm in self.copied_gms + [self.gm]:
            gm._unregister_create_node_hook(self._node_creation_hook)
            gm._unregister_erase_node_hook(self._node_erase_hook)
            gm._unregister_replace_node_hook(self._node_replace_hook)
            gm._unregister_deepcopy_hook(self._deepcopy_hook)

        if self.log_url is None:
            return

        # Only emit dot files for passes that actually changed the graph.
        if not (self.created_nodes or self.erased_nodes):
            return

        self._highlight_nodes(self.input_dot_graph, self.erased_nodes)
        self.input_dot_graph.write(self._dot_path("input_graph"))

        output_dot_graph = self._draw_dot_graph()
        self._highlight_nodes(output_dot_graph, self.created_nodes)
        output_dot_graph.write(self._dot_path("output_graph"))

    def _draw_dot_graph(self):
        """Render ``self.gm`` as a dot graph via ``FxGraphDrawer``.

        ``ignore_getattr`` and ``ignore_parameters_and_buffers`` are enabled.
        """
        return FxGraphDrawer(
            self.gm,
            self.passname,
            ignore_getattr=True,
            ignore_parameters_and_buffers=True,
        ).get_dot_graph()

    def _dot_path(self, suffix: str) -> str:
        """Return ``<log_url>/pass_<index>_<passname>_<suffix>.dot``."""
        assert self.log_url is not None
        return os.path.join(
            self.log_url,
            f"pass_{self._pass_index}_{self.passname}_{suffix}.dot",
        )

    @staticmethod
    def _highlight_nodes(dot_graph, highlighted: set[str]) -> None:
        """Fill dot nodes named in ``highlighted`` yellow and all others grey."""
        for dot_node in dot_graph.get_node_list():
            color = "yellow" if dot_node.get_name() in highlighted else "grey"
            dot_node.obj_dict["attributes"]["fillcolor"] = color

    def get_node_creation_hook(self):
        """Return a hook that records each node created during the pass.

        The hook adds the node's name to ``created_nodes``, registers it in
        ``name_to_node`` and appends a CREATE ``NodeSource`` to its
        ``meta["from_node"]``.
        """

        # We have to return a function instead of using a class method directly
        # to avoid max recursion issue when deepcopy a graph module within the
        # context manager.
        def on_node_creation(node):
            self.created_nodes.add(node.name)
            self.name_to_node[node.name] = node
            source = NodeSource(None, self.passname, NodeSourceAction.CREATE)
            node.meta.setdefault("from_node", []).append(source)

        return on_node_creation

    def get_node_erase_hook(self):
        """Return a hook that records each node erased during the pass."""

        def on_node_erase(node):
            self.erased_nodes.add(node.name)
            self.name_to_node.pop(node.name, None)

        return on_node_erase

    def get_node_replace_hook(self):
        """Return a hook that records provenance when ``old`` is replaced.

        The hook receives the replaced node ``old``, the name of the
        replacement node ``new`` and the affected ``user``. It rewrites the
        replacement's ``meta["from_node"]``: the plain CREATE source added by
        this pass is dropped and a REPLACE source pointing at ``old`` is
        appended. That source is REPLACE+CREATE if the replacement was itself
        created during this pass.
        """

        def created_this_pass(source):
            # True for the source appended by on_node_creation in this pass.
            return source.pass_name == self.passname and source.action == [
                NodeSourceAction.CREATE
            ]

        def on_node_replace(old: Node, new: str, user: Node):
            # Update node meta when replacing old node with new node
            new_node = self.name_to_node.get(new)
            if new_node is None:
                return
            assert isinstance(new_node, Node)

            existing_sources = new_node.meta.get("from_node", [])
            # replace hook is called once for each user of old;
            # this avoids adding duplicated source nodes
            if old.name in {s.name for s in existing_sources}:
                return

            action = [NodeSourceAction.REPLACE]
            if new_node.name in self.created_nodes:
                action.append(NodeSourceAction.CREATE)

            # remove redundant source added on node creation, then add the
            # replacement source
            new_from_node = [
                source for source in existing_sources if not created_this_pass(source)
            ]
            new_from_node.append(NodeSource(old, self.passname, action))
            new_node.meta["from_node"] = new_from_node

        return on_node_replace

    def get_deepcopy_hook(self):
        """Return a hook that remembers graph modules deep-copied from ``gm``."""

        def on_deepcopy(gm):
            self.copied_gms.append(gm)

        return on_deepcopy