mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Dynamo] add debug logging for graph region expansion (#141382)
This PR adds debug logging for the region expansion algorithm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141382 Approved by: https://github.com/williamwen42 ghstack dependencies: #141381
This commit is contained in:
committed by
PyTorch MergeBot
parent
96c36a6947
commit
49e4307686
@ -3,26 +3,11 @@ import contextlib
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch._dynamo.test_case import TestCase
|
||||
from torch._dynamo.testing import extract_graph_and_tracker
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
def extract_graph_and_tracker(fn, *args, **kwargs):
|
||||
gm = None
|
||||
region_tracker = None
|
||||
|
||||
def extract_graph_backend(_gm, *args, **kwargs):
|
||||
nonlocal gm
|
||||
nonlocal region_tracker
|
||||
gm = _gm
|
||||
region_tracker = InstructionTranslator.current_tx().output.region_tracker
|
||||
return _gm
|
||||
|
||||
torch.compile(backend=extract_graph_backend, fullgraph=True)(fn)(*args, **kwargs)
|
||||
return gm.graph, region_tracker
|
||||
|
||||
|
||||
def get_nodes_by_name(graph, names):
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
|
@ -10,7 +10,11 @@ import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
import torch.distributed as dist
|
||||
from torch._dynamo.testing import empty_line_normalizer, skipIfNotPy311
|
||||
from torch._dynamo.testing import (
|
||||
empty_line_normalizer,
|
||||
extract_graph_and_tracker,
|
||||
skipIfNotPy311,
|
||||
)
|
||||
from torch._dynamo.trace_rules import _as_posix_path
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -731,6 +735,29 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
|
||||
self.assertGreater(len(records), 0)
|
||||
self.assertLess(len(records), 3)
|
||||
|
||||
@make_logging_test(graph_region_expansion=True)
|
||||
def test_graph_region_expansion(self, records):
|
||||
with torch._dynamo.config.patch("track_nodes_for_deduplication", True):
|
||||
|
||||
def inner_fn(x, y):
|
||||
x0 = x + 1
|
||||
y0 = y + 2
|
||||
z = x0.sum() + y0.sum()
|
||||
return z
|
||||
|
||||
def fn(x, y):
|
||||
o0 = inner_fn(x, y)
|
||||
o1 = torch.sin(o0)
|
||||
o2 = inner_fn(x, o1)
|
||||
o3 = inner_fn(x, y)
|
||||
return o2 * o3 * o3
|
||||
|
||||
graph, tracker = extract_graph_and_tracker(
|
||||
fn, torch.randn(10, 10), torch.randn(10, 10)
|
||||
)
|
||||
tracker.get_identical_regions(graph)
|
||||
self.assertGreater(len(records), 0)
|
||||
|
||||
@skipIfTorchDynamo("too slow")
|
||||
@make_logging_test(**torch._logging.DEFAULT_LOGGING)
|
||||
def test_default_logging(self, records):
|
||||
@ -864,6 +891,7 @@ exclusions = {
|
||||
"cudagraph_static_inputs",
|
||||
"benchmarking",
|
||||
"loop_ordering",
|
||||
"graph_region_expansion",
|
||||
}
|
||||
for name in torch._logging._internal.log_registry.artifact_names:
|
||||
if name not in exclusions:
|
||||
|
@ -18,6 +18,7 @@ from typing import (
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import torch._logging
|
||||
import torch.fx
|
||||
from torch._subclasses.fake_tensor import FakeTensor
|
||||
from torch.utils._pytree import tree_flatten
|
||||
@ -36,6 +37,13 @@ IdenticalNodes = List[Node]
|
||||
GlobalStateKey = Tuple[bool, bool, int, bool, bool, torch.dtype, bool, bool, bool, bool]
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
graph_expansion_log = torch._logging.getArtifactLogger(
|
||||
__name__, "graph_region_expansion"
|
||||
)
|
||||
|
||||
|
||||
def debug_log(msg: str, *args) -> None: # type: ignore[no-untyped-def]
|
||||
graph_expansion_log.debug(msg, *args)
|
||||
|
||||
|
||||
def _extract_tensor_metadata_for_node_hash(
|
||||
@ -278,6 +286,9 @@ def fully_expand_region_group(
|
||||
seen_nodes: Set[Node],
|
||||
is_identical_fn: Callable[[Node, Node], bool],
|
||||
) -> None:
|
||||
debug_log("--------------------------------------------------")
|
||||
debug_log("expanding new region group: %s", regions)
|
||||
|
||||
# All regions should start with 1 node
|
||||
assert all(len(region) == 1 for region in regions)
|
||||
region_iters = []
|
||||
@ -306,7 +317,12 @@ def fully_expand_region_group(
|
||||
for region_it in region_iters[1:]:
|
||||
node = region_it.next()
|
||||
|
||||
debug_log("--------------------")
|
||||
debug_log("considering adding: %s, cur_node: %s", node, current_node)
|
||||
debug_log("previously claimed nodes: %s", node in seen_nodes)
|
||||
debug_log("%s", seen_nodes)
|
||||
if node:
|
||||
debug_log("is_identical: %s", is_identical_fn(node, current_node))
|
||||
add_node &= (
|
||||
node not in seen_nodes
|
||||
and node not in nodes_to_add_set
|
||||
@ -317,9 +333,13 @@ def fully_expand_region_group(
|
||||
else:
|
||||
add_node = False
|
||||
|
||||
debug_log("--------------------")
|
||||
|
||||
if add_node:
|
||||
for region, region_it, node in zip(regions, region_iters, nodes_to_add):
|
||||
region.append(node)
|
||||
debug_log("adding %s's children", node)
|
||||
debug_log("%s %s", node.args, list(node.kwargs.items()))
|
||||
region_it.add_children(node)
|
||||
seen_nodes.add(node)
|
||||
|
||||
@ -328,3 +348,6 @@ def fully_expand_region_group(
|
||||
# Ensure regions are sorted in topological order
|
||||
for region in regions:
|
||||
region.reverse()
|
||||
|
||||
debug_log("end expand new region group: %s", regions)
|
||||
debug_log("--------------------------------------------------")
|
||||
|
@ -62,6 +62,23 @@ def remove_optimized_module_prefix(name: str) -> str:
|
||||
return re.sub(r"^_orig_mod[.]", "", name)
|
||||
|
||||
|
||||
def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
gm = None
|
||||
region_tracker = None
|
||||
|
||||
def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
nonlocal gm
|
||||
nonlocal region_tracker
|
||||
gm = _gm
|
||||
region_tracker = InstructionTranslator.current_tx().output.region_tracker
|
||||
return _gm
|
||||
|
||||
torch.compile(backend=extract_graph_backend, fullgraph=True)(fn)(*args, **kwargs)
|
||||
return gm.graph, region_tracker # type: ignore[union-attr]
|
||||
|
||||
|
||||
def collect_results(
|
||||
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
|
||||
) -> List[Any]:
|
||||
|
@ -240,6 +240,7 @@ def set_logs(
|
||||
compiled_autograd_verbose: bool = False,
|
||||
cudagraph_static_inputs: bool = False,
|
||||
benchmarking: bool = False,
|
||||
graph_region_expansion: bool = False,
|
||||
):
|
||||
"""
|
||||
Sets the log level for individual components and toggles individual log
|
||||
@ -416,6 +417,9 @@ def set_logs(
|
||||
cudagraph_static_inputs (:class:`bool`):
|
||||
Whether to emit debug info for cudagraph static input detection. Default: ``False``
|
||||
|
||||
graph_region_expansion (:class:`bool`):
|
||||
Whether to emit the detailed steps of the duplicate graph region tracker expansion algorithm. Default: ``False``
|
||||
|
||||
|
||||
Example::
|
||||
|
||||
@ -514,6 +518,7 @@ def set_logs(
|
||||
compiled_autograd_verbose=compiled_autograd_verbose,
|
||||
cudagraph_static_inputs=cudagraph_static_inputs,
|
||||
benchmarking=benchmarking,
|
||||
graph_region_expansion=graph_region_expansion,
|
||||
)
|
||||
|
||||
|
||||
|
@ -191,5 +191,10 @@ register_artifact(
|
||||
"Detailed Inductor benchmarking information.",
|
||||
off_by_default=True,
|
||||
)
|
||||
register_artifact(
|
||||
"graph_region_expansion",
|
||||
"Logs detailed steps of the duplicate graph region tracker expansion algorithm",
|
||||
off_by_default=True,
|
||||
)
|
||||
|
||||
register_artifact("custom_format_test_artifact", "Testing only", log_format="")
|
||||
|
Reference in New Issue
Block a user