Change IR node's stack traces to return a set of stack traces only (#160701)

Summary: There can be excessive stack trace outputs in TORCH_LOGS="+inductor" when a single line of code corresponds to many post grad nodes, e.g. `self.multihead_attn(x, x, x)`, in that case, we'll see the same stack trace many times in the IR node, spamming the output log. So we change to return a set of stack traces.

Test Plan:
CI

Rollback Plan:

Differential Revision: D80310549

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160701
Approved by: https://github.com/angelayi
This commit is contained in:
Shangdi Yu
2025-08-15 19:31:54 +00:00
committed by PyTorch MergeBot
parent b78968b4d1
commit fa75ba9303

View File

@ -603,12 +603,10 @@ class IRNode:
def get_defining_op(self) -> Optional[Operation]:
return None
def get_stack_traces(self) -> dict[str, str]:
def get_stack_traces(self) -> OrderedSet[str]:
# Return stack traces to user model code
# A single IRNode could correspond to multiple lines of code
# Group nodes by their stack traces to deduplicate
nodes_to_stack_trace = {}
stack_traces: OrderedSet[str] = OrderedSet()
origins = self.origins
if isinstance(self, ExternKernel):
origin_node = self.get_origin_node()
@ -617,7 +615,7 @@ class IRNode:
for node in origins:
if hasattr(node, "stack_trace") and node.stack_trace:
# nodes in the backward graph don't have mapping to pre_grad_graph
nodes_to_stack_trace["post_grad+" + node.name] = node.stack_trace
stack_traces.add(node.stack_trace)
else:
pre_grad_nodes = (
torch._inductor.debug._inductor_post_to_pre_grad_nodes.get(
@ -633,9 +631,8 @@ class IRNode:
)
)
if stack_trace:
nodes_to_stack_trace["pre_grad+" + node_name] = stack_trace
return nodes_to_stack_trace
stack_traces.add(stack_trace)
return stack_traces
def common_repr(self, shorten: bool = True) -> Sequence[str]:
origins = f"origins={getattr(self, 'origins', '')}"
@ -646,8 +643,8 @@ class IRNode:
return [origins]
stack_trace_str = []
for stack_trace in self.get_stack_traces().values():
stack_trace_str.append("stack_traces = {{")
for stack_trace in self.get_stack_traces():
stack_trace_str.append("stack_traces = {")
stack_trace_str += stack_trace.split("\n")
stack_trace_str.append("}")
return [origins] + stack_trace_str