Compare commits


1 Commit

SHA1: f4c50d4713
Message: improve dynamo graph capture stack trace for custom ops
Date: 2025-10-17 13:44:34 -07:00
2 changed files with 55 additions and 3 deletions

View File

@@ -38,7 +38,12 @@ from torch._functorch.aot_autograd import (
 )
 from torch._guards import tracing, TracingContext
 from torch.nn.attention.flex_attention import create_block_mask, flex_attention
-from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
+from torch.testing._internal.common_utils import (
+    requires_cuda,
+    run_tests,
+    skipIfCrossRef,
+    TestCase,
+)


 def graph_capture(model, inputs, with_export):
@@ -962,6 +967,45 @@ class inner_f(torch.nn.Module):
             ('call_function', 't_3', {'pp_stage': 0})""",
         )

+    @skipIfCrossRef
+    def test_custom_op_stack_trace(self):
+        @torch.library.custom_op("my_lib::foo", mutates_args={})
+        def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            return x + y
+
+        @foo.register_fake
+        def foo_fake_impl(x, y):
+            return torch.empty_like(x)
+
+        def foo_setup_context(ctx, inputs, output):
+            pass
+
+        def foo_backward(ctx, grad_output):
+            return grad_output, grad_output
+
+        foo.register_autograd(foo_backward, setup_context=foo_setup_context)
+
+        class CustomOpModule(torch.nn.Module):
+            def forward(self, x, y):
+                return foo(x, y)
+
+        model = CustomOpModule()
+        inputs = (torch.randn(4, 3), torch.randn(4, 3))
+        gm = graph_capture(model, inputs, with_export=True)
+
+        foo_node = None
+        for node in gm.graph.nodes:
+            if node.op == "call_function" and node.name == "foo":
+                foo_node = node
+                break
+        self.assertTrue(foo_node is not None)
+        self.assertTrue("return foo(x, y)" in foo_node.meta.get("stack_trace", None))
+        self.assertTrue("return foo(x, y)" in gm.print_readable(print_output=False))
+        self.assertFalse("self._opoverload" in foo_node.meta.get("stack_trace", None))
+        self.assertFalse("self._opoverload" in gm.print_readable(print_output=False))


 if __name__ == "__main__":
     run_tests()
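Outside this test suite, a rough way to poke at the same metadata is to export a small module that calls a custom op and read each node's stack_trace meta. This is a sketch, not part of the PR: the op name my_lib::bar and module M below are made up, and it assumes a PyTorch build recent enough to record stack traces on exported graph nodes (and to carry this change, for the filtered output).

```python
# Minimal sketch (hypothetical names): inspect the stack_trace metadata that
# graph capture records on call_function nodes for a custom op.
import torch


@torch.library.custom_op("my_lib::bar", mutates_args=())
def bar(x: torch.Tensor) -> torch.Tensor:
    return x * 2


@bar.register_fake
def _(x):
    return torch.empty_like(x)


class M(torch.nn.Module):
    def forward(self, x):
        return bar(x)


ep = torch.export.export(M(), (torch.randn(3),))
for node in ep.graph.nodes:
    if node.op == "call_function":
        # With the change in this commit, the trace should point at user code
        # (e.g. "return bar(x)") rather than torch.library internals.
        print(node.name, node.meta.get("stack_trace", ""))
```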

View File

@@ -67,6 +67,7 @@ from torch.fx.experimental.symbolic_shapes import (
     is_symbolic,
     ShapeEnv,
     Specialization,
+    uninteresting_files,
 )
 from torch.fx.node import Target
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
@@ -3170,11 +3171,18 @@ class SubgraphTracer(fx.Tracer):
                 if not tx.is_co_filename_from_nn_modules():
                     frame_summaries.append(tx.frame_summary())
                 tx = getattr(tx, "parent", None)

+            filtered_frame_summaries = [
+                frame
+                for frame in frame_summaries
+                if frame.filename not in uninteresting_files()
+            ]
+
             # Reverse the frame_summaries, such that the innermost frame is at the last
-            frame_summaries.reverse()
+            filtered_frame_summaries.reverse()

             # official from_list stub doesn't have new-style type
-            msgs = traceback.StackSummary.from_list(frame_summaries).format()
+            msgs = traceback.StackSummary.from_list(filtered_frame_summaries).format()
             rv.node.stack_trace = "".join(msgs)

             if (
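For reference, the core trick in this hunk is plain stdlib traceback machinery: drop FrameSummary entries whose filename is on a deny-list before formatting the stack trace. A self-contained sketch of that technique, where INTERNAL_FILES is a hypothetical stand-in for uninteresting_files():

```python
# Standalone sketch of the frame-filtering technique: build a stack-trace
# string that skips frames coming from "uninteresting" files.
import traceback

# Hypothetical deny-list; in the PR this role is played by uninteresting_files().
INTERNAL_FILES = {"/site-packages/torch/_library/custom_ops.py"}


def user_facing_stack_trace() -> str:
    # extract_stack() returns a StackSummary (a list of FrameSummary objects),
    # ordered with the innermost frame last.
    frame_summaries = traceback.extract_stack()
    filtered = [
        frame
        for frame in frame_summaries
        if frame.filename not in INTERNAL_FILES
    ]
    # from_list accepts FrameSummary objects (or 4-tuples) and format() yields
    # the familiar "File ..., line ..., in ..." lines.
    return "".join(traceback.StackSummary.from_list(filtered).format())


if __name__ == "__main__":
    print(user_facing_stack_trace())
```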