only print GraphModule during fx.Interpreter errors if valid (#148090)

Came up in https://www.internalfb.com/diff/D69057074?dst_version_fbid=970771615000938&transaction_fbid=1723357345264461 - we need to make sure the GraphModule is valid before calling `print_readable` on it

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148090
Approved by: https://github.com/jamesjwu, https://github.com/zou3519
ghstack dependencies: #147749
This commit is contained in:
Brian Hirsh
2025-02-27 07:33:55 -08:00
committed by PyTorch MergeBot
parent 5a14ff8ace
commit c6d1038aaa

View File

@ -173,7 +173,11 @@ class Interpreter:
if self.extra_traceback:
msg = f"While executing {node.format_node()}"
msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg)
if isinstance(self.module, GraphModule):
if (
isinstance(self.module, GraphModule)
and self.module.graph is not None
and isinstance(self.module.graph, torch.fx.Graph)
):
msg += f"\nGraphModule: {self.module.print_readable(print_output=False, include_stride=True)}\n"
msg += f"\nOriginal traceback:\n{node.stack_trace}"
e.args = (msg,) + e.args[1:]