Compare commits

...

1 Commit

Author SHA1 Message Date
8e824c2c6b Reapply "Add model code stack trace to torch.profile (#166677)"
This reverts commit c86540f12038ffc4a3c9ecdbecb01ce73e0967c9.

ghstack-source-id: 76c16d9c91217dc1871a825c4748e05608b33daf
Pull-Request: https://github.com/pytorch/pytorch/pull/167107
2025-11-05 11:05:17 -08:00
6 changed files with 429 additions and 5 deletions

View File

@@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
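
The new record_func flag defaults to False, so existing callers are unchanged. A minimal sketch of the extended signature in use (the traced function here is illustrative):

import torch
import torch.fx

def f(x):
    return x.relu() + 1

gm = torch.fx.symbolic_trace(f)
# With record_func=True, each call_function/call_method/call_module node in the
# generated code is wrapped in a _RecordFunctionFast range named '## <node index> ##'.
code = gm.graph.python_code(root_module="self", record_func=True)
print(code.src)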

View File

@@ -72,9 +72,16 @@ from torch.testing._internal.common_utils import (
IS_WINDOWS,
run_tests,
skipIfTorchDynamo,
skipIfRocm,
)
from torch.testing._internal.jit_utils import JitTestCase
import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
try:
from torchvision import models as torchvision_models
@@ -201,6 +208,36 @@ def side_effect_func(x: torch.Tensor):
print(x)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A canonical string representation of the enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
class TestFX(JitTestCase):
def setUp(self):
super().setUp()
@@ -4212,6 +4249,153 @@ def forward(self, args_list: List[torch.Tensor]){maybe_return_annotation}:
# restore the mutable-operations checking flag
torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from FX metadata registry.
"""
# Simple test model
class TestModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(10, 16)
self.relu = torch.nn.ReLU()
self.linear2 = torch.nn.Linear(16, 10)
def forward(self, x):
x = self.linear1(x)
x = self.relu(x)
x = self.linear2(x)
return x
model = TestModel().cuda()
# Compile the model
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"))
# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::t node=t stack_trace=x = self.linear1(x)
event=aten::transpose node=t stack_trace=x = self.linear1(x)
event=aten::as_strided node=t stack_trace=x = self.linear1(x)
event=aten::addmm node=addmm stack_trace=x = self.linear1(x)
event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x)
event=aten::relu node=relu stack_trace=x = self.relu(x)
event=aten::clamp_min node=relu stack_trace=x = self.relu(x)
event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x)
event=aten::t node=t_1 stack_trace=x = self.linear2(x)
event=aten::transpose node=t_1 stack_trace=x = self.linear2(x)
event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x)
event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x)
event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_multiple_modules(self):
"""
Test that multiple compiled modules under the same profiler session
have their events correctly augmented with stack traces.
"""
class ModelA(torch.nn.Module):
def forward(self, x):
return x + 1
class ModelB(torch.nn.Module):
def forward(self, x):
return x - 1
model_a = ModelA().cuda()
model_b = ModelB().cuda()
# Compile both models
compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True)
compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_a(torch.randn(10, 10, device="cuda"))
_ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
# Profile both models in the same session
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result_a = compiled_a(torch.randn(10, 10, device="cuda"))
result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::add node=add stack_trace=return x + 1
event=cudaLaunchKernel node=add stack_trace=return x + 1
event=aten::sub node=sub stack_trace=return x - 1
event=cudaLaunchKernel node=sub stack_trace=return x - 1"""
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch._dynamo.config.patch("enrich_profiler_metadata", True)
def test_profiler_nested_graph_modules(self):
"""
Test that nested graph modules (e.g., graph modules calling subgraphs)
have their events correctly augmented with stack traces.
"""
# Model with nested structure
class Mod(torch.nn.Module):
def __init__(self):
super().__init__()
self.c = 5
@torch.compiler.nested_compile_region
def forward(self, x, y):
m = torch.mul(x, y)
s = m.sin()
a = s + self.c
return a
model = Mod().cuda()
# Compile the model (this may create nested graph modules)
compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
# Profile
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda"))
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::mul node=mul stack_trace=m = torch.mul(x, y)
event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y)
event=aten::sin node=sin stack_trace=s = m.sin()
event=cudaLaunchKernel node=sin stack_trace=s = m.sin()
event=aten::add node=add stack_trace=a = s + self.c
event=cudaLaunchKernel node=add stack_trace=a = s + self.c"""
)
def run_getitem_target():
from torch.fx._symbolic_trace import _wrapped_methods_to_patch
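
Outside of a TestCase, the enrichment pipeline above can be driven by hand with the _enrich_profiler_traces helper; a rough sketch (model and shapes are illustrative):

import torch
from torch.profiler import profile, ProfilerActivity

torch._dynamo.config.enrich_profiler_metadata = True
model = torch.compile(torch.nn.Linear(8, 8).cuda(), backend="aot_eager")
x = torch.randn(2, 8, device="cuda")
model(x)  # warmup so compilation happens outside the profiled region
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    model(x)
# One 'event=... node=... stack_trace=...' line per augmented event:
print(_enrich_profiler_traces(prof))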

View File

@@ -1224,3 +1224,43 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by start time for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)

View File

@@ -443,6 +443,7 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@@ -798,6 +799,10 @@ class CodeGen:
return
raise NotImplementedError(f"node: {node.op} {node.target}")
if record_func:
body.append(
"_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n"
)
for i, node in enumerate(nodes):
# NOTE: emit_node does not emit a string with newline. It depends
# on delete_unused_values to append one
@@ -807,8 +812,22 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in (
"call_function",
"call_method",
"call_module",
)
if do_record:
# The double hash ## convention is used by post-processing to find the fx markers
body.append(
f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n"
)
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
if record_func:
body.append("_rf.__exit__(None, None, None)\n")
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
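
For a graph with a single addmm node, the instrumented output would look roughly like this (node name, index, and argument names are illustrative):

# Sketch of codegen output with record_func=True:
_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()
_rf_addmm = torch._C._profiler._RecordFunctionFast('## 1 ##'); _rf_addmm.__enter__()
addmm = torch.ops.aten.addmm.default(arg2_1, arg0_1, t);  arg2_1 = arg0_1 = t = None
_rf_addmm.__exit__(None, None, None)
_rf.__exit__(None, None, None)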
@@ -1760,6 +1779,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@@ -1827,6 +1847,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def _python_code(
@@ -1839,6 +1860,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@@ -1849,6 +1871,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def __str__(self) -> str:

View File

@@ -861,14 +861,18 @@ class {module_name}(torch.nn.Module):
if isinstance(self._graph._codegen, _PyTreeCodeGen):
self._in_spec = self._graph._codegen.pytree_info.in_spec
self._out_spec = self._graph._codegen.pytree_info.out_spec
python_code = self._graph.python_code(root_module="self")
from torch._dynamo import config as dynamo_config
python_code = self._graph.python_code(
root_module="self", record_func=dynamo_config.enrich_profiler_metadata
)
self._code = python_code.src
self._lineno_map = python_code._lineno_map
self._prologue_start = python_code._prologue_start
cls = type(self)
co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {}
from torch._dynamo import config as dynamo_config
if dynamo_config.enrich_profiler_metadata:
# Generate metadata and register for profiler augmentation
@@ -885,7 +889,6 @@ class {module_name}(torch.nn.Module):
# This ensures the same code+metadata always generates the same filename
hash_value = _metadata_hash(self._code, node_metadata)
file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}"
filename = f"{file_stem}.py"
# Only include co_filename to use it directly as the cache key
@@ -905,6 +908,13 @@ class {module_name}(torch.nn.Module):
_register_fx_metadata(filename, metadata)
# Replace the placeholder in generated code with actual filename
# The double hash ## convention is used by post-processing to find the fx markers
self._code = self._code.replace(
"torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')",
f"torch._C._profiler._RecordFunctionFast('## {filename} ##')",
)
cls.forward = _forward_from_src(self._code, python_code.globals, co_fields)
# Determine whether this class explicitly defines a __call__ implementation
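
Concretely, the preamble emitted by CodeGen is rewritten once the metadata filename is known (prefix and hash value illustrative):

# Emitted by CodeGen (placeholder key):
_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()
# After recompile()'s replacement, keyed by the registered metadata filename:
_rf = torch._C._profiler._RecordFunctionFast('## fx_generated_3fa2b1c4.py ##'); _rf.__enter__()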

View File

@@ -4,7 +4,7 @@ import operator
import re
from collections import deque
from dataclasses import dataclass
from typing import TYPE_CHECKING
from typing import Any, Literal, Optional, TYPE_CHECKING
from torch.autograd.profiler import profile
from torch.profiler import DeviceType
@@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None:
with profile():
pass
@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""
timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[Literal["filename", "node"]]
identifier: Optional[str | int]
event: dict[str, Any]
@dataclass
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""
context_type: Literal["filename", "node"]
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None # Thread ID associated with this context
def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
"""
Maps recorded profiler events to their corresponding fx nodes and adds stack traces.
Builds a timeline of all events (regular ops and FX markers for filenames/nodes),
sorts by timestamp, then processes chronologically while maintaining a context stack of active
filename/node scopes. Regular events are augmented with stack traces and node names from the
innermost active context. Runtime is O(n log n) for n events.
Args:
traced_data: JSON dict of profiler events loaded from a Chrome trace
Returns:
None. Matching events in traced_data are augmented in place with
stack_trace and node_name entries under their args.
"""
from torch.fx.traceback import _FX_METADATA_REGISTRY
trace_events = traced_data.get("traceEvents", [])
# Create event timeline
event_timeline: list[TimelineEvent] = []
def is_fx_marker_event(event):
return (
event.get("cat") == "cpu_op"
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
def append_fx_marker_event(event_type, identifier, event):
start_ts = event["ts"]
end_ts = start_ts + event["dur"]
event_timeline.append(
TimelineEvent(start_ts, "start", event_type, identifier, event)
)
event_timeline.append(
TimelineEvent(end_ts, "end", event_type, identifier, event)
)
for event in trace_events:
if "ts" not in event or "dur" not in event:
continue
if is_fx_marker_event(event):
content = event["name"][3:-3]
if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
else:
try:
node_index = int(content)
except ValueError:
pass
else:
append_fx_marker_event("node", node_index, event)
else:
# Regular event that needs augmentation
start_ts = event["ts"]
event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event))
# Sort by timestamp
event_timeline.sort(key=lambda x: x.timestamp)
# Process events in chronological order with a stack
context_stack: list[ContextStackEntry] = []
# Invariant: every start event has a corresponding end event
for timeline_event in event_timeline:
match timeline_event.event_type:
case "start":
assert timeline_event.identifier is not None
if timeline_event.marker_type == "filename":
assert isinstance(timeline_event.identifier, str)
# Push filename context - query metadata registry on-demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")
context_stack.append(
ContextStackEntry(
"filename", timeline_event.identifier, metadata, tid
)
)
elif timeline_event.marker_type == "node":
# Find the current filename from stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == "filename"
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
break
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
)
)
case "end":
# Pop from stack - search backwards to find matching context
for i in range(len(context_stack) - 1, -1, -1):
ctx_entry = context_stack[i]
if (
timeline_event.marker_type == ctx_entry.context_type
and timeline_event.identifier == ctx_entry.identifier
):
context_stack.pop(i)
break
case "regular":
# Apply metadata from current context stack
# Find the most specific context (node takes precedence over filename)
# Only augment events with the same tid as the file/node event matched
current_stack_trace = None
current_node_name = None
event_tid = timeline_event.event.get("tid")
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if ctx_entry.context_type == "node" and ctx_entry.metadata:
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)
current_node_name = ctx_entry.metadata.get("name", "")
# TODO: decide whether to attach only the innermost node's stack trace or the
# traces of all enclosing nodes when graph modules are nested
break
# Augment the event
if current_stack_trace or current_node_name:
args = timeline_event.event.setdefault("args", {})
if current_stack_trace:
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name