mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Change IR node's stack traces to return a set of stack traces only (#160701)
Summary: There can be excessive stack trace outputs in TORCH_LOGS="+inductor" when a single line of code corresponds to many post grad nodes, e.g. `self.multihead_attn(x, x, x)`, in that case, we'll see the same stack trace many times in the IR node, spamming the output log. So we change to return a set of stack traces. Test Plan: CI Rollback Plan: Differential Revision: D80310549 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160701 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
b78968b4d1
commit
fa75ba9303
@ -603,12 +603,10 @@ class IRNode:
|
||||
def get_defining_op(self) -> Optional[Operation]:
|
||||
return None
|
||||
|
||||
def get_stack_traces(self) -> dict[str, str]:
|
||||
def get_stack_traces(self) -> OrderedSet[str]:
|
||||
# Return stack traces to user model code
|
||||
# A single IRNode could correspond to multiple lines of code
|
||||
|
||||
# Group nodes by their stack traces to deduplicate
|
||||
nodes_to_stack_trace = {}
|
||||
stack_traces: OrderedSet[str] = OrderedSet()
|
||||
origins = self.origins
|
||||
if isinstance(self, ExternKernel):
|
||||
origin_node = self.get_origin_node()
|
||||
@ -617,7 +615,7 @@ class IRNode:
|
||||
for node in origins:
|
||||
if hasattr(node, "stack_trace") and node.stack_trace:
|
||||
# nodes in the backward graph don't have mapping to pre_grad_graph
|
||||
nodes_to_stack_trace["post_grad+" + node.name] = node.stack_trace
|
||||
stack_traces.add(node.stack_trace)
|
||||
else:
|
||||
pre_grad_nodes = (
|
||||
torch._inductor.debug._inductor_post_to_pre_grad_nodes.get(
|
||||
@ -633,9 +631,8 @@ class IRNode:
|
||||
)
|
||||
)
|
||||
if stack_trace:
|
||||
nodes_to_stack_trace["pre_grad+" + node_name] = stack_trace
|
||||
|
||||
return nodes_to_stack_trace
|
||||
stack_traces.add(stack_trace)
|
||||
return stack_traces
|
||||
|
||||
def common_repr(self, shorten: bool = True) -> Sequence[str]:
|
||||
origins = f"origins={getattr(self, 'origins', '')}"
|
||||
@ -646,8 +643,8 @@ class IRNode:
|
||||
return [origins]
|
||||
|
||||
stack_trace_str = []
|
||||
for stack_trace in self.get_stack_traces().values():
|
||||
stack_trace_str.append("stack_traces = {{")
|
||||
for stack_trace in self.get_stack_traces():
|
||||
stack_trace_str.append("stack_traces = {")
|
||||
stack_trace_str += stack_trace.split("\n")
|
||||
stack_trace_str.append("}")
|
||||
return [origins] + stack_trace_str
|
||||
|
Reference in New Issue
Block a user