[Dynamo] add debug logging for graph region expansion (#141382)

This PR adds debug logging for the region expansion algorithm.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141382
Approved by: https://github.com/williamwen42
ghstack dependencies: #141381
This commit is contained in:
Michael Lazos
2024-12-10 11:55:57 -08:00
committed by PyTorch MergeBot
parent 96c36a6947
commit 49e4307686
6 changed files with 80 additions and 17 deletions

View File

@ -62,6 +62,23 @@ def remove_optimized_module_prefix(name: str) -> str:
return re.sub(r"^_orig_mod[.]", "", name)
def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-def]
from torch._dynamo.symbolic_convert import InstructionTranslator
gm = None
region_tracker = None
def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def]
nonlocal gm
nonlocal region_tracker
gm = _gm
region_tracker = InstructionTranslator.current_tx().output.region_tracker
return _gm
torch.compile(backend=extract_graph_backend, fullgraph=True)(fn)(*args, **kwargs)
return gm.graph, region_tracker # type: ignore[union-attr]
def collect_results(
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
) -> List[Any]: