mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Dynamo] add debug logging for graph region expansion (#141382)
This PR adds debug logging for the region expansion algorithm. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141382 Approved by: https://github.com/williamwen42 ghstack dependencies: #141381
This commit is contained in:
committed by
PyTorch MergeBot
parent
96c36a6947
commit
49e4307686
@ -62,6 +62,23 @@ def remove_optimized_module_prefix(name: str) -> str:
|
||||
return re.sub(r"^_orig_mod[.]", "", name)
|
||||
|
||||
|
||||
def extract_graph_and_tracker(fn, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
gm = None
|
||||
region_tracker = None
|
||||
|
||||
def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
nonlocal gm
|
||||
nonlocal region_tracker
|
||||
gm = _gm
|
||||
region_tracker = InstructionTranslator.current_tx().output.region_tracker
|
||||
return _gm
|
||||
|
||||
torch.compile(backend=extract_graph_backend, fullgraph=True)(fn)(*args, **kwargs)
|
||||
return gm.graph, region_tracker # type: ignore[union-attr]
|
||||
|
||||
|
||||
def collect_results(
|
||||
model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any
|
||||
) -> List[Any]:
|
||||
|
Reference in New Issue
Block a user