mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Improve dynamo graph capture stack trace for custom ops (#165693)
For a custom op ``` @torch.library.custom_op("my_lib::foo", mutates_args={}) def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y ``` ppl could call `torch.ops.my_lib.foo()` or directly call `foo()` in the `forward` of an `nn.Module` These two calling conventions will lead to the same node in the output graph, but different stack traces. When directly calling `foo()`, the displayed stack_trace in the graph will be ``` # File: .../pytorch/torch/_library/custom_ops.py:687 in __call__, code: return self._opoverload(*args, **kwargs) ``` This is not useful so we filter it out. ``` python test/functorch/test_aot_joint_with_descriptors.py -k test_custom_op_stack_trace ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165693 Approved by: https://github.com/SherlockNoMad, https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
017d2985f3
commit
e4d6c56ffb
@ -67,6 +67,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
||||
is_symbolic,
|
||||
ShapeEnv,
|
||||
Specialization,
|
||||
uninteresting_files,
|
||||
)
|
||||
from torch.fx.node import Target
|
||||
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
|
||||
@ -3170,11 +3171,18 @@ class SubgraphTracer(fx.Tracer):
|
||||
if not tx.is_co_filename_from_nn_modules():
|
||||
frame_summaries.append(tx.frame_summary())
|
||||
tx = getattr(tx, "parent", None)
|
||||
|
||||
filtered_frame_summaries = [
|
||||
frame
|
||||
for frame in frame_summaries
|
||||
if frame.filename not in uninteresting_files()
|
||||
]
|
||||
|
||||
# Reverse the frame_summaries, such that the innermost frame is at the last
|
||||
frame_summaries.reverse()
|
||||
filtered_frame_summaries.reverse()
|
||||
|
||||
# official from_list stub doesn't have new-style type
|
||||
msgs = traceback.StackSummary.from_list(frame_summaries).format()
|
||||
msgs = traceback.StackSummary.from_list(filtered_frame_summaries).format()
|
||||
rv.node.stack_trace = "".join(msgs)
|
||||
|
||||
if (
|
||||
|
Reference in New Issue
Block a user