From c6d1038aaae6c53a30657e12ace493f74482d6ac Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Thu, 27 Feb 2025 07:33:55 -0800 Subject: [PATCH] only print GraphModule during fx.Interpreter errors if valid (#148090) Came up in https://www.internalfb.com/diff/D69057074?dst_version_fbid=970771615000938&transaction_fbid=1723357345264461 - we need to make sure the GraphModule is valid before calling `print_readable` on it Pull Request resolved: https://github.com/pytorch/pytorch/pull/148090 Approved by: https://github.com/jamesjwu, https://github.com/zou3519 ghstack dependencies: #147749 --- torch/fx/interpreter.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index d23ecfd9510c..e8b348ce1ca9 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -173,7 +173,11 @@ class Interpreter: if self.extra_traceback: msg = f"While executing {node.format_node()}" msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg) - if isinstance(self.module, GraphModule): + if ( + isinstance(self.module, GraphModule) + and self.module.graph is not None + and isinstance(self.module.graph, torch.fx.Graph) + ): msg += f"\nGraphModule: {self.module.print_readable(print_output=False, include_stride=True)}\n" msg += f"\nOriginal traceback:\n{node.stack_trace}" e.args = (msg,) + e.args[1:]