[Dynamo] add debug logging for graph region expansion (#141382)

This PR adds debug logging for the region expansion algorithm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141382
Approved by: https://github.com/williamwen42
ghstack dependencies: #141381
This commit is contained in:
Michael Lazos
2024-12-10 11:55:57 -08:00
committed by PyTorch MergeBot
parent 96c36a6947
commit 49e4307686
6 changed files with 80 additions and 17 deletions

View File

@ -3,26 +3,11 @@ import contextlib
import torch
import torch.fx
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import extract_graph_and_tracker
from torch.utils._pytree import tree_map
def extract_graph_and_tracker(fn, *args, **kwargs):
gm = None
region_tracker = None
def extract_graph_backend(_gm, *args, **kwargs):
nonlocal gm
nonlocal region_tracker
gm = _gm
region_tracker = InstructionTranslator.current_tx().output.region_tracker
return _gm
torch.compile(backend=extract_graph_backend, fullgraph=True)(fn)(*args, **kwargs)
return gm.graph, region_tracker
def get_nodes_by_name(graph, names):
nodes = []
for node in graph.nodes:

View File

@ -10,7 +10,11 @@ import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.distributed as dist
from torch._dynamo.testing import empty_line_normalizer, skipIfNotPy311
from torch._dynamo.testing import (
empty_line_normalizer,
extract_graph_and_tracker,
skipIfNotPy311,
)
from torch._dynamo.trace_rules import _as_posix_path
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_utils import (
@ -731,6 +735,29 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
self.assertGreater(len(records), 0)
self.assertLess(len(records), 3)
@make_logging_test(graph_region_expansion=True)
def test_graph_region_expansion(self, records):
with torch._dynamo.config.patch("track_nodes_for_deduplication", True):
def inner_fn(x, y):
x0 = x + 1
y0 = y + 2
z = x0.sum() + y0.sum()
return z
def fn(x, y):
o0 = inner_fn(x, y)
o1 = torch.sin(o0)
o2 = inner_fn(x, o1)
o3 = inner_fn(x, y)
return o2 * o3 * o3
graph, tracker = extract_graph_and_tracker(
fn, torch.randn(10, 10), torch.randn(10, 10)
)
tracker.get_identical_regions(graph)
self.assertGreater(len(records), 0)
@skipIfTorchDynamo("too slow")
@make_logging_test(**torch._logging.DEFAULT_LOGGING)
def test_default_logging(self, records):
@ -864,6 +891,7 @@ exclusions = {
"cudagraph_static_inputs",
"benchmarking",
"loop_ordering",
"graph_region_expansion",
}
for name in torch._logging._internal.log_registry.artifact_names:
if name not in exclusions:

View File

@ -18,6 +18,7 @@ from typing import (
TypeVar,
)
import torch._logging
import torch.fx
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import tree_flatten
@ -36,6 +37,13 @@ IdenticalNodes = List[Node]
GlobalStateKey = Tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
log = logging.getLogger(__name__)
graph_expansion_log = torch._logging.getArtifactLogger(
__name__, "graph_region_expansion"
)
def debug_log(msg: str, *args) -> None: # type: ignore[no-untyped-def]
graph_expansion_log.debug(msg, *args)
def _extract_tensor_metadata_for_node_hash(
@ -278,6 +286,9 @@ def fully_expand_region_group(
seen_nodes: Set[Node],
is_identical_fn: Callable[[Node, Node], bool],
) -> None:
debug_log("--------------------------------------------------")
debug_log("expanding new region group: %s", regions)
# All regions should start with 1 node
assert all(len(region) == 1 for region in regions)
region_iters = []
@ -306,7 +317,12 @@ def fully_expand_region_group(
for region_it in region_iters[1:]:
node = region_it.next()
debug_log("--------------------")
debug_log("considering adding: %s, cur_node: %s", node, current_node)
debug_log("previously claimed nodes: %s", node in seen_nodes)
debug_log("%s", seen_nodes)
if node:
debug_log("is_identical: %s", is_identical_fn(node, current_node))
add_node &= (
node not in seen_nodes
and node not in nodes_to_add_set
@ -317,9 +333,13 @@ def fully_expand_region_group(
else:
add_node = False
debug_log("--------------------")
if add_node:
for region, region_it, node in zip(regions, region_iters, nodes_to_add):
region.append(node)
debug_log("adding %s's children", node)
debug_log("%s %s", node.args, list(node.kwargs.items()))
region_it.add_children(node)
seen_nodes.add(node)
@ -328,3 +348,6 @@ def fully_expand_region_group(
# Ensure regions are sorted in topological order
for region in regions:
region.reverse()
debug_log("end expand new region group: %s", regions)
debug_log("--------------------------------------------------")

View File

@ -62,6 +62,23 @@ def remove_optimized_module_prefix(name: str) -> str:
return re.sub(r"^_orig_mod[.]", "", name)
def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-def]
from torch._dynamo.symbolic_convert import InstructionTranslator
gm = None
region_tracker = None
def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def]
nonlocal gm
nonlocal region_tracker
gm = _gm
region_tracker = InstructionTranslator.current_tx().output.region_tracker
return _gm
torch.compile(backend=extract_graph_backend, fullgraph=True)(fn)(*args, **kwargs)
return gm.graph, region_tracker # type: ignore[union-attr]
def collect_results(
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> List[Any]:

View File

@ -240,6 +240,7 @@ def set_logs(
compiled_autograd_verbose: bool = False,
cudagraph_static_inputs: bool = False,
benchmarking: bool = False,
graph_region_expansion: bool = False,
):
"""
Sets the log level for individual components and toggles individual log
@ -416,6 +417,9 @@ def set_logs(
cudagraph_static_inputs (:class:`bool`):
Whether to emit debug info for cudagraph static input detection. Default: ``False``
graph_region_expansion (:class:`bool`):
Whether to emit the detailed steps of the duplicate graph region tracker expansion algorithm. Default: ``False``
Example::
@ -514,6 +518,7 @@ def set_logs(
compiled_autograd_verbose=compiled_autograd_verbose,
cudagraph_static_inputs=cudagraph_static_inputs,
benchmarking=benchmarking,
graph_region_expansion=graph_region_expansion,
)

View File

@ -191,5 +191,10 @@ register_artifact(
"Detailed Inductor benchmarking information.",
off_by_default=True,
)
register_artifact(
"graph_region_expansion",
"Logs detailed steps of the duplicate graph region tracker expansion algorithm",
off_by_default=True,
)
register_artifact("custom_format_test_artifact", "Testing only", log_format="")