mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
This is a follow-up to #164653 that continues applying `UP035` fixes, with the goal of eventually enabling this rule. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165214 Approved by: https://github.com/ezyang
230 lines
7.6 KiB
Python
230 lines
7.6 KiB
Python
# mypy: allow-untyped-defs
|
|
import os
|
|
from collections.abc import Callable
|
|
from typing import Optional, TypeVar
|
|
|
|
from torch.fx import Graph, Node
|
|
from torch.fx._compatibility import compatibility
|
|
from torch.fx.graph_module import GraphModule
|
|
from torch.fx.traceback import NodeSource, NodeSourceAction
|
|
|
|
|
|
# Result type of the pass callables accepted by
# GraphTransformObserver.apply_gm_pass / apply_graph_pass.
T = TypeVar("T")
|
|
|
|
|
|
from .graph_drawer import FxGraphDrawer
|
|
|
|
|
|
# Public API of this module.
__all__ = ["GraphTransformObserver"]
|
|
|
|
|
|
@compatibility(is_backward_compatible=False)
class GraphTransformObserver:
    """Observe a single graph-transformation pass applied to a ``GraphModule``.

    Use it as a context manager, either directly or through
    :meth:`apply_gm_pass` / :meth:`apply_graph_pass`. While the context is
    active, hooks registered on ``gm`` do the following:

    * record the names of the nodes the pass creates and erases;
    * append ``NodeSource`` provenance entries to ``node.meta["from_node"]``.

    If ``log_url`` is set, the observer also renders the graph with
    ``FxGraphDrawer`` before and after the pass. When the pass created or
    erased at least one node, both renderings are written as ``.dot`` files
    under ``log_url``:

    * the input graph, with erased nodes highlighted;
    * the output graph, with created nodes highlighted.

    The observer does nothing unless ``log_url`` is set or
    ``torch._inductor.config.trace.provenance_tracking_level == 1``.
    """

    # Count of observers constructed with a log_url. It is used as the
    # ``pass_<n>`` prefix of the written .dot file names.
    __pass_count = 0

    def __init__(
        self,
        gm: GraphModule,
        passname: str,
        subsystem: Optional[str] = None,
        log_url: Optional[str] = None,
    ):
        """
        Args:
            gm: Module whose graph the observed pass transforms. Hooks are
                registered on it while the observer's context is active.
            passname: Name of the pass. It is used in provenance entries and
                in the names of the written ``.dot`` files.
            subsystem: If given, the pass can be disabled through
                ``CompilerBisector.disable_subsystem("inductor", subsystem, ...)``.
            log_url: Directory the ``.dot`` files are written to. Unless
                otherwise specified, it is inferred to be
                ``torch._inductor.config.trace.log_url_for_graph_xform``.
                If it is still ``None``, nothing is written.
        """
        from torch._inductor import config as inductor_config

        self.gm = gm
        self.passname = passname
        self.subsystem = subsystem

        if log_url is None:
            log_url = inductor_config.trace.log_url_for_graph_xform
        self.log_url = log_url

        # Provenance tracking needs the hooks even when nothing is logged.
        self.active = (
            self.log_url is not None
            or inductor_config.trace.provenance_tracking_level == 1
        )

        if self.active:
            self.erased_nodes: set[str] = set()
            self.created_nodes: set[str] = set()
            self.name_to_node: dict[str, Node] = {}
            # Graph modules deepcopied from self.gm while the context is
            # active. We remove our hooks from them on exit as well.
            self.copied_gms: list[GraphModule] = []

            self._node_creation_hook = self.get_node_creation_hook()
            self._node_erase_hook = self.get_node_erase_hook()
            self._node_replace_hook = self.get_node_replace_hook()
            self._deepcopy_hook = self.get_deepcopy_hook()

        # If log_url is None, we don't log anything.
        if self.log_url is None:
            return

        GraphTransformObserver.__pass_count += 1
        # Capture this pass's index now. Observers constructed while the pass
        # runs (e.g. nested passes) advance the shared counter before
        # __exit__. Reading the counter there would mislabel this pass's
        # files, or make them collide with another pass's files.
        self._pass_index = GraphTransformObserver.__pass_count
        # Snapshot the graph before the pass runs.
        self.input_dot_graph = self._draw_graph()

    @classmethod
    def get_current_pass_count(cls):
        """Return how many observers have been constructed with a ``log_url``."""
        return cls.__pass_count

    def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> Optional[T]:
        """Run ``pass_fn(self.gm)`` under observation.

        Returns the pass result, or ``None`` if the compiler bisector
        disabled this pass's subsystem.
        """
        with self:
            if not self._check_disable_pass():
                return pass_fn(self.gm)

        return None

    def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> Optional[T]:
        """Run ``pass_fn(self.gm.graph)`` under observation.

        Returns the pass result, or ``None`` if the compiler bisector
        disabled this pass's subsystem.
        """
        with self:
            if not self._check_disable_pass():
                return pass_fn(self.gm.graph)

        return None

    def _check_disable_pass(self):
        """Return True if the compiler bisector disables this pass's subsystem.

        Always returns False when no ``subsystem`` was given.
        """
        if self.subsystem is None:
            return False

        from torch._inductor.compiler_bisector import CompilerBisector

        def debug_info():
            return self.passname

        return CompilerBisector.disable_subsystem(
            "inductor", self.subsystem, debug_info
        )

    def _draw_graph(self):
        """Render the current state of ``self.gm`` as a dot graph."""
        return FxGraphDrawer(
            self.gm,
            self.passname,
            ignore_getattr=True,
            ignore_parameters_and_buffers=True,
        ).get_dot_graph()

    def _write_dot_graph(self, dot_graph, highlighted: set[str], suffix: str) -> None:
        """Colour ``dot_graph`` and write it to ``log_url``.

        Nodes named in ``highlighted`` are filled yellow and all other nodes
        grey. The graph is written to
        ``<log_url>/pass_<index>_<passname>_<suffix>.dot``.
        """
        for dot_node in dot_graph.get_node_list():
            dot_node.obj_dict["attributes"]["fillcolor"] = (
                "yellow" if dot_node.get_name() in highlighted else "grey"
            )
        assert self.log_url is not None
        dot_graph.write(
            os.path.join(
                self.log_url,
                f"pass_{self._pass_index}_{self.passname}_{suffix}.dot",
            )
        )

    def __enter__(self):
        """Register the observation hooks on ``self.gm`` and reset the records.

        Also indexes the graph's current nodes by name, so the replace hook
        can resolve replacement targets.
        """
        if not self.active:
            return self

        self.gm._register_create_node_hook(self._node_creation_hook)
        self.gm._register_erase_node_hook(self._node_erase_hook)
        self.gm._register_replace_node_hook(self._node_replace_hook)
        self.gm._register_deepcopy_hook(self._deepcopy_hook)

        self.erased_nodes.clear()
        self.created_nodes.clear()
        self.name_to_node.clear()
        self.copied_gms.clear()

        for node in self.gm.graph.nodes:
            self.name_to_node[node.name] = node

        return self

    def __exit__(self, type, value, tb):
        """Remove the hooks and, if logging is enabled, write the graphs.

        Returns None, so any exception raised inside the context propagates.
        """
        if not self.active:
            return

        # Deepcopies made inside the context were recorded by the deepcopy
        # hook. Remove the hooks from them too.
        for gm in self.copied_gms + [self.gm]:
            gm._unregister_create_node_hook(self._node_creation_hook)
            gm._unregister_erase_node_hook(self._node_erase_hook)
            gm._unregister_replace_node_hook(self._node_replace_hook)
            gm._unregister_deepcopy_hook(self._deepcopy_hook)

        if self.log_url is None:
            return

        # Only write graphs when the pass actually created or erased nodes.
        if not (self.created_nodes or self.erased_nodes):
            return

        self._write_dot_graph(self.input_dot_graph, self.erased_nodes, "input_graph")
        self._write_dot_graph(self._draw_graph(), self.created_nodes, "output_graph")

    def get_node_creation_hook(self):
        """Return the hook that records node creation.

        The hook does three things:

        * records the created node's name;
        * indexes the node by name;
        * appends a ``CREATE`` ``NodeSource`` for this pass to
          ``node.meta["from_node"]``.
        """

        # We have to return a function instead of using a class method directly
        # to avoid max recursion issue when deepcopy a graph module within the context manager.
        def on_node_creation(node):
            self.created_nodes.add(node.name)
            self.name_to_node[node.name] = node
            source = NodeSource(None, self.passname, NodeSourceAction.CREATE)
            # Rebind a new list instead of appending in place. The existing
            # list may be shared with another node (e.g. via a shallow-copied
            # meta dict). on_node_replace likewise rebinds rather than mutates.
            node.meta["from_node"] = [*node.meta.get("from_node", []), source]

        return on_node_creation

    def get_node_erase_hook(self):
        """Return the hook that records an erased node and drops it from the index."""

        def on_node_erase(node):
            self.erased_nodes.add(node.name)
            self.name_to_node.pop(node.name, None)

        return on_node_erase

    def get_node_replace_hook(self):
        """Return the hook that records ``old`` as a provenance source of its replacement."""

        def on_node_replace(old: Node, new: str, user: Node):
            # Update node meta when replacing old node with new node.
            new_node = self.name_to_node.get(new)
            if new_node is None:
                return

            assert isinstance(new_node, Node)

            existing_sources = new_node.meta.get("from_node", [])

            # The replace hook is called once for each user of old.
            # This check avoids adding duplicated source nodes.
            if old.name in {s.name for s in existing_sources}:
                return

            action = [NodeSourceAction.REPLACE]
            if new_node.name in self.created_nodes:
                action.append(NodeSourceAction.CREATE)

            def created_this_pass(source):
                return source.pass_name == self.passname and source.action == [
                    NodeSourceAction.CREATE
                ]

            # Remove the redundant source added on node creation; the
            # REPLACE(+CREATE) entry below supersedes it.
            new_from_node = [
                source for source in existing_sources if not created_this_pass(source)
            ]
            new_from_node.append(NodeSource(old, self.passname, action))
            new_node.meta["from_node"] = new_from_node

        return on_node_replace

    def get_deepcopy_hook(self):
        """Return the hook that records graph modules deepcopied inside the context."""

        def on_deepcopy(gm):
            self.copied_gms.append(gm)

        return on_deepcopy