Compare commits

...

1 Commit

Author SHA1 Message Date
71a8a769c5 Add stack trace to kineto trace for aot_eager 2025-09-30 17:53:35 -07:00
3 changed files with 71 additions and 3 deletions

View File

@ -1,6 +1,10 @@
import os

# Whether to disable showing progress on compilation passes.
# Need a separate config here: importing the dynamo config would create a
# circular import.
disable_progress = True

# If True, also show node names in each pass. Great for small models, but
# quite noisy on larger ones.
verbose_progress = False

# When the TORCH_PROFILE_INTERPRETER_STACK_TRACE env var is "1", the FX
# interpreter records node stack traces into the profiler output.
profiler_interpreter_stack_trace = os.getenv("TORCH_PROFILE_INTERPRETER_STACK_TRACE", "0") == "1"

View File

@ -1,12 +1,15 @@
# mypy: allow-untyped-defs
import inspect
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from typing import Any, Optional, TYPE_CHECKING, Union
import torch
import torch.fx.traceback as fx_traceback
from torch._logging import trace_structured
from torch.hub import tqdm
from torch.profiler import profile, record_function, ProfilerActivity
import torch._C._profiler as _profiler
import json
from . import config
from ._compatibility import compatibility
@ -161,6 +164,16 @@ class Interpreter:
delay=0,
)
graph_id = id(self.graph)
if config.profiler_interpreter_stack_trace:
stack_traces = {}
for node in self.graph.nodes:
if node.stack_trace:
stack_traces[f"## {node.name}:{graph_id} interpreter ##"] = node.stack_trace.replace("\"", "'")
# add stack traces to profiler metadata
torch.autograd._add_metadata_json(f"node_stack_traces:{graph_id}", json.dumps(stack_traces))
for node in self.graph.nodes:
pbar.update(1)
if node in self.env:
@ -169,9 +182,13 @@ class Interpreter:
# where the caller has pre-populated `env` with
# values for a subset of the program.
continue
profiler_context = nullcontext()
if config.profiler_interpreter_stack_trace:
profiler_context = torch.profiler.record_function(f"## {node.name}:{graph_id} interpreter ##")
try:
self.env[node] = self.run_node(node)
with profiler_context:
self.env[node] = self.run_node(node)
except Exception as e:
if self.extra_traceback:
msg = f"While executing {node.format_node()}"

View File

@ -5,6 +5,7 @@ import traceback
from contextlib import contextmanager
from enum import Enum
from typing import Any, Optional, Union
import json
from torch._utils_internal import signpost_event
@ -396,3 +397,49 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
},
)
return {}
def populate_stack_traces_to_kineto_trace(file_name: str, update_file: bool = True) -> dict:
    """
    Attach recorded node stack traces to user_annotation entries in a kineto trace.

    The exported kineto trace json carries one or more "node_stack_traces:<graph_id>"
    metadata mappings (annotation name -> stack trace). This merges them into a
    single lookup table and copies each stack trace onto the args of every
    user_annotation event whose name matches.

    Args:
        file_name (str): The filename of the exported kineto trace json.
        update_file (bool): Whether to rewrite the kineto trace json file with
            the stack traces attached.

    Returns:
        dict: Modified trace data with stack traces attached to matching entries.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original json.load(open(...)) leaked the handle).
    with open(file_name) as f:
        trace_data = json.load(f)

    # Merge every per-graph stack-trace mapping into one lookup table.
    all_stack_traces: dict = {}
    for key, traces in trace_data.items():
        if key.startswith("node_stack_traces"):
            all_stack_traces.update(traces)

    if not all_stack_traces:
        log.warning("No stack traces found in kineto trace data")
        return trace_data

    # Attach the stack trace to each matching user_annotation event.
    for event in trace_data.get("traceEvents", []):
        if event.get("cat") == "user_annotation":
            event_name = event.get("name")
            if event_name in all_stack_traces:
                # Some events may lack an "args" dict; create it on demand
                # instead of raising KeyError.
                event.setdefault("args", {})["stack_trace"] = all_stack_traces[event_name]

    if update_file:
        with open(file_name, "w") as f:
            json.dump(trace_data, f)
    return trace_data