mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Improve dynamo graph capture stack trace for custom ops (#165693)
For a custom op ``` @torch.library.custom_op("my_lib::foo", mutates_args={}) def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y ``` ppl could call `torch.ops.my_lib.foo()` or directly call `foo()` in the `forward` of an `nn.Module` These two calling conventions will lead to the same node in the output graph, but different stack traces. When directly calling `foo()`, the displayed stack_trace in the graph will be ``` # File: .../pytorch/torch/_library/custom_ops.py:687 in __call__, code: return self._opoverload(*args, **kwargs) ``` This is not useful so we filter it out. ``` python test/functorch/test_aot_joint_with_descriptors.py -k test_custom_op_stack_trace ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165693 Approved by: https://github.com/SherlockNoMad, https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
017d2985f3
commit
e4d6c56ffb
@ -38,7 +38,12 @@ from torch._functorch.aot_autograd import (
|
||||
)
|
||||
from torch._guards import tracing, TracingContext
|
||||
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
|
||||
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
|
||||
from torch.testing._internal.common_utils import (
|
||||
requires_cuda,
|
||||
run_tests,
|
||||
skipIfCrossRef,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
|
||||
def graph_capture(model, inputs, with_export):
|
||||
@ -962,6 +967,45 @@ class inner_f(torch.nn.Module):
|
||||
('call_function', 't_3', {'pp_stage': 0})""",
|
||||
)
|
||||
|
||||
@skipIfCrossRef
|
||||
def test_custom_op_stack_trace(self):
|
||||
@torch.library.custom_op("my_lib::foo", mutates_args={})
|
||||
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
return x + y
|
||||
|
||||
@foo.register_fake
|
||||
def foo_fake_impl(x, y):
|
||||
return torch.empty_like(x)
|
||||
|
||||
def foo_setup_context(ctx, inputs, output):
|
||||
pass
|
||||
|
||||
def foo_backward(ctx, grad_output):
|
||||
return grad_output, grad_output
|
||||
|
||||
foo.register_autograd(foo_backward, setup_context=foo_setup_context)
|
||||
|
||||
class CustomOpModule(torch.nn.Module):
|
||||
def forward(self, x, y):
|
||||
return foo(x, y)
|
||||
|
||||
model = CustomOpModule()
|
||||
inputs = (torch.randn(4, 3), torch.randn(4, 3))
|
||||
|
||||
gm = graph_capture(model, inputs, with_export=True)
|
||||
|
||||
foo_node = None
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "call_function" and node.name == "foo":
|
||||
foo_node = node
|
||||
break
|
||||
|
||||
self.assertTrue(foo_node is not None)
|
||||
self.assertTrue("return foo(x, y)" in foo_node.meta.get("stack_trace", None))
|
||||
self.assertTrue("return foo(x, y)" in gm.print_readable(print_output=False))
|
||||
self.assertFalse("self._opoverload" in foo_node.meta.get("stack_trace", None))
|
||||
self.assertFalse("self._opoverload" in gm.print_readable(print_output=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -67,6 +67,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
is_symbolic,
|
||||
ShapeEnv,
|
||||
Specialization,
|
||||
uninteresting_files,
|
||||
)
|
||||
from torch.fx.node import Target
|
||||
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
|
||||
@ -3170,11 +3171,18 @@ class SubgraphTracer(fx.Tracer):
|
||||
if not tx.is_co_filename_from_nn_modules():
|
||||
frame_summaries.append(tx.frame_summary())
|
||||
tx = getattr(tx, "parent", None)
|
||||
|
||||
filtered_frame_summaries = [
|
||||
frame
|
||||
for frame in frame_summaries
|
||||
if frame.filename not in uninteresting_files()
|
||||
]
|
||||
|
||||
# Reverse the frame_summaries, such that the innermost frame is at the last
|
||||
frame_summaries.reverse()
|
||||
filtered_frame_summaries.reverse()
|
||||
|
||||
# official from_list stub doesn't have new-style type
|
||||
msgs = traceback.StackSummary.from_list(frame_summaries).format()
|
||||
msgs = traceback.StackSummary.from_list(filtered_frame_summaries).format()
|
||||
rv.node.stack_trace = "".join(msgs)
|
||||
|
||||
if (
|
||||
|
Reference in New Issue
Block a user