From e4d6c56ffb3d680d3874f0dd01907aee7ed2d3c5 Mon Sep 17 00:00:00 2001
From: Yiming Zhou
Date: Sat, 18 Oct 2025 03:48:18 +0000
Subject: [PATCH] Improve dynamo graph capture stack trace for custom ops
 (#165693)

For a custom op

```
@torch.library.custom_op("my_lib::foo", mutates_args={})
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y
```

people can call `torch.ops.my_lib.foo()` or call `foo()` directly in the
`forward` of an `nn.Module`.

These two calling conventions lead to the same node in the output graph but
to different stack traces. When `foo()` is called directly, the stack trace
displayed in the graph is

```
# File: .../pytorch/torch/_library/custom_ops.py:687 in __call__, code: return self._opoverload(*args, **kwargs)
```

This frame is not useful, so we filter it out.

Test:

```
python test/functorch/test_aot_joint_with_descriptors.py -k test_custom_op_stack_trace
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165693
Approved by: https://github.com/SherlockNoMad, https://github.com/williamwen42
---
 .../test_aot_joint_with_descriptors.py       | 46 ++++++++++++++++++
 torch/_dynamo/output_graph.py                | 12 ++++-
 2 files changed, 55 insertions(+), 3 deletions(-)

diff --git a/test/functorch/test_aot_joint_with_descriptors.py b/test/functorch/test_aot_joint_with_descriptors.py
index d797b36748d0..24d9042bc9c9 100644
--- a/test/functorch/test_aot_joint_with_descriptors.py
+++ b/test/functorch/test_aot_joint_with_descriptors.py
@@ -38,7 +38,12 @@ from torch._functorch.aot_autograd import (
 )
 from torch._guards import tracing, TracingContext
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
-from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
+from torch.testing._internal.common_utils import (
+    requires_cuda,
+    run_tests,
+    skipIfCrossRef,
+    TestCase,
+)
 
 
 def graph_capture(model, inputs, with_export):
@@ -962,6 +967,45 @@ class inner_f(torch.nn.Module):
 ('call_function', 't_3', {'pp_stage': 0})""",
         )
 
+    @skipIfCrossRef
+    def test_custom_op_stack_trace(self):
+        @torch.library.custom_op("my_lib::foo", mutates_args={})
+        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return x + y
+
+        @foo.register_fake
+        def foo_fake_impl(x, y):
+            return torch.empty_like(x)
+
+        def foo_setup_context(ctx, inputs, output):
+            pass
+
+        def foo_backward(ctx, grad_output):
+            return grad_output, grad_output
+
+        foo.register_autograd(foo_backward, setup_context=foo_setup_context)
+
+        class CustomOpModule(torch.nn.Module):
+            def forward(self, x, y):
+                return foo(x, y)
+
+        model = CustomOpModule()
+        inputs = (torch.randn(4, 3), torch.randn(4, 3))
+
+        gm = graph_capture(model, inputs, with_export=True)
+
+        foo_node = None
+        for node in gm.graph.nodes:
+            if node.op == "call_function" and node.name == "foo":
+                foo_node = node
+                break
+
+        self.assertTrue(foo_node is not None)
+        self.assertTrue("return foo(x, y)" in foo_node.meta.get("stack_trace", None))
+        self.assertTrue("return foo(x, y)" in gm.print_readable(print_output=False))
+        self.assertFalse("self._opoverload" in foo_node.meta.get("stack_trace", None))
+        self.assertFalse("self._opoverload" in gm.print_readable(print_output=False))
+
 
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index feeeed32b9d1..9bce964c3f1a 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -67,6 +67,7 @@ from torch.fx.experimental.symbolic_shapes import (
     is_symbolic,
     ShapeEnv,
     Specialization,
+    uninteresting_files,
 )
 from torch.fx.node import Target
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
@@ -3170,11 +3171,18 @@ class SubgraphTracer(fx.Tracer):
                 if not tx.is_co_filename_from_nn_modules():
                     frame_summaries.append(tx.frame_summary())
                 tx = getattr(tx, "parent", None)
+
+            filtered_frame_summaries = [
+                frame
+                for frame in frame_summaries
+                if frame.filename not in uninteresting_files()
+            ]
+
             # Reverse the frame_summaries, such that the innermost frame is at the last
-            frame_summaries.reverse()
+            filtered_frame_summaries.reverse()
 
             # official from_list stub doesn't have new-style type
-            msgs = traceback.StackSummary.from_list(frame_summaries).format()
+            msgs = traceback.StackSummary.from_list(filtered_frame_summaries).format()
             rv.node.stack_trace = "".join(msgs)
 
             if (
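
As a supplement to the commit message, a minimal standalone sketch of the two
calling conventions it describes; the module names `DirectCall` and
`OpOverloadCall` are illustrative and not part of the PR:

```
import torch

# The custom op from the commit message; a fake impl is registered so
# the op can be traced with FakeTensors.
@torch.library.custom_op("my_lib::foo", mutates_args={})
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

@foo.register_fake
def _(x, y):
    return torch.empty_like(x)

class DirectCall(torch.nn.Module):
    def forward(self, x, y):
        # Direct call: passes through CustomOpDef.__call__ in
        # torch/_library/custom_ops.py, the frame this patch filters out.
        return foo(x, y)

class OpOverloadCall(torch.nn.Module):
    def forward(self, x, y):
        # Dispatcher call: no extra Python frame ends up in the trace.
        return torch.ops.my_lib.foo(x, y)
```

Both modules capture to the same `call_function` node for `my_lib.foo`; with
the filtering in `SubgraphTracer` (frames whose filename appears in
`uninteresting_files()` are dropped), both report the user's `forward` line in
`node.meta["stack_trace"]` rather than the internal `self._opoverload` frame.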