Compare commits

...

1 Commit

12 changed files with 353 additions and 122 deletions

View File

@@ -27,9 +27,9 @@ from torch._inductor.fx_passes.post_grad import post_grad_passes
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
 from torch._inductor.virtualized import V
-from torch.testing._internal.common_utils import IS_MACOS
+from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm
 from torch.testing._internal.triton_utils import requires_cuda_and_triton
+from torch.profiler import profile, ProfilerActivity

 try:
     from .test_aot_inductor_utils import AOTIRunnerUtil
@@ -941,5 +941,63 @@ copy_tests(
 )

+from torch.profiler._utils import _enrich_profiler_traces
+
+
+class TestProfilerStackTraceAugmentation(TestCase):
+    """
+    Test that profiler events are correctly augmented with stack traces
+    from both FX metadata and inductor kernel stack traces.
+    """
+
+    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+    @skipIfRocm
+    @torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
+    @config.patch("fallback_by_default", True)  # TODO: update the config patch to inductor lite mode
+    @torch.compiler.config.patch("force_disable_caches", True)
+    def test_profiler_inductor_stack_trace_augmentation(self):
+        """
+        Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
+        augments profiler events with stack traces from inductor kernel metadata.
+        """
+
+        # Test model similar to test.py
+        class TestModel(torch.nn.Module):
+            def forward(self, c):
+                d = c * 2
+                d = d + 1
+                return d
+
+        device = "cuda"
+        model = TestModel().to(device)
+        c = torch.randn((64, 32), device=device)
+
+        # Force disable caches to ensure fresh compilation
+        torch.compiler.config.force_disable_caches = True
+
+        # Compile the model
+        compiled_model = torch.compile(model, fullgraph=True)
+
+        # Warmup
+        for _ in range(3):
+            _ = compiled_model(c)
+
+        # Profile the compiled model
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+        ) as prof:
+            compiled_model(c)
+
+        actual_traces = _enrich_profiler_traces(prof)
+        self.assertExpectedInline(
+            actual_traces,
+            """\
+event=aten::mul node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
+event=cudaLaunchKernel node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
+event=aten::add node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1
+event=cudaLaunchKernel node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1""",
+        )
+
+    # TODO: add a test that no RecordFunctionFast markers are generated when
+    # enrich_profiler_metadata is not enabled.
+
 if __name__ == "__main__":
     run_tests()

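A minimal sketch of the end-to-end flow the new test exercises, for trying the feature outside the test harness. It assumes a CUDA build and leans on the private helpers this PR introduces (`_enrich_profiler_traces`, the `enrich_profiler_metadata` config), so it is illustrative rather than stable API:

```python
import torch
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import _enrich_profiler_traces

model = torch.nn.Linear(32, 32).cuda()  # assumes a CUDA build
compiled = torch.compile(model, fullgraph=True)
x = torch.randn(8, 32, device="cuda")

# Enable the metadata enrichment added by this PR for both compile and profile.
with torch.fx.experimental._config.patch("enrich_profiler_metadata", True):
    compiled(x)  # warmup so profiling sees the steady-state compiled graph
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        compiled(x)

# One line per augmented event: event=<name> node=<node> stack_trace=<source line>
print(_enrich_profiler_traces(prof))
```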
View File

@@ -76,11 +76,8 @@ from torch.testing._internal.common_utils import (
 )
 from torch.testing._internal.jit_utils import JitTestCase

-import json
-import tempfile
 from torch.profiler import profile, ProfilerActivity
-from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
-from torch.autograd.profiler_util import _canonicalize_profiler_events
+from torch.profiler._utils import _enrich_profiler_traces

 try:
     from torchvision import models as torchvision_models
@@ -208,36 +205,6 @@ def side_effect_func(x: torch.Tensor):
     print(x)

-def _enrich_profiler_traces(prof):
-    """
-    Helper function to extract and augment profiler events with stack traces.
-
-    Args:
-        prof: A torch.profiler.profile object
-
-    Returns:
-        A string representing enriched events
-    """
-    with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
-        trace_file = f.name
-        prof.export_chrome_trace(trace_file)
-
-        with open(trace_file) as f:
-            trace_data = json.load(f)
-
-        map_recorded_events_to_aten_ops_with_stack_trace(
-            trace_data
-        )
-
-        events = []
-        for event in trace_data["traceEvents"]:
-            if "args" in event and "stack_trace" in event["args"]:
-                events.append(event)
-
-        actual_traces = _canonicalize_profiler_events(events)
-        return actual_traces
-
 class TestFX(JitTestCase):
     def setUp(self):
         super().setUp()

View File

@@ -70,7 +70,7 @@ from .common import (
 )
 from .cpp_utils import cexpr
 from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta
+from torch.fx.experimental import _config as fx_experimental_config

 if TYPE_CHECKING:
     from collections.abc import Iterator, Sequence
@@ -1120,6 +1120,37 @@ class PythonWrapperCodegen(CodeGen):
         # Additional files that are dependent to the wrapper (ex. cubin files)
         self.additional_files = []

+        # Used to emit RecordFunctionFast markers that can be matched
+        # with profiler traces for provenance tracking.
+        #
+        # Stores the (kernel_name, debug_handle) tuple for the kernel
+        # currently being generated.
+        self.current_kernel_debug_handle: Optional[tuple[str, int]] = None
+        # set_current_kernel_debug_handle: flag that controls whether
+        # write_provenance_debug_handle() should update current_kernel_debug_handle.
+        # It is managed automatically by kernel_debug_handle_context().
+        self.set_current_kernel_debug_handle: bool = False
+
+    @contextlib.contextmanager
+    def kernel_debug_handle_context(self):
+        """
+        Context manager for kernel debug handle tracking.
+
+        Within the context, self.current_kernel_debug_handle can be updated
+        via wrapper.write_provenance_debug_handle; both the flag and the
+        handle are restored when the context exits.
+        """
+        old_flag_value = self.set_current_kernel_debug_handle
+        old_handle_value = self.current_kernel_debug_handle
+        self.set_current_kernel_debug_handle = True
+        try:
+            yield
+        finally:
+            self.set_current_kernel_debug_handle = old_flag_value
+            self.current_kernel_debug_handle = old_handle_value
+
     @staticmethod
     def create(
         is_subgraph: bool,
@@ -1510,8 +1541,27 @@ class PythonWrapperCodegen(CodeGen):
     def generate_end(self, result: IndentedBuffer) -> None:
         return

+    def generate_record_function_start(self) -> Optional[str]:
+        record_func = self.current_kernel_debug_handle and fx_experimental_config.enrich_profiler_metadata
+        if record_func:
+            assert self.current_kernel_debug_handle
+            kernel_name, debug_handle = self.current_kernel_debug_handle
+            kernel_debug_handle = f"{kernel_name}:{debug_handle}"
+            self.writeline(
+                f"_rf_enter = torch._C._profiler._RecordFunctionFast('## inductor_kernel:{kernel_debug_handle} ##'); _rf_enter.__enter__()"
+            )
+            return "_rf_enter"
+        else:
+            return None
+
+    def generate_record_function_end(self, record_func_var: Optional[str]):
+        if record_func_var:
+            self.writeline(f"{record_func_var}.__exit__(None, None, None)")
+
     def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
+        record_func_var = self.generate_record_function_start()
         self.writeline(ExternKernelAllocLine(self, node))
+        self.generate_record_function_end(record_func_var)

     def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc):
         node.codegen_comment(self)
@@ -1671,7 +1721,9 @@ class PythonWrapperCodegen(CodeGen):
         raw_args: Sequence[Any],
         outputs: Sequence[ir.Buffer],
     ) -> None:
+        record_func_var = self.generate_record_function_start()
         self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})")
+        self.generate_record_function_end(record_func_var)

     def generate(self, is_inference):
         with dynamo_timed("PythonWrapperCodegen.generate"):
@@ -3142,6 +3194,8 @@ class PythonWrapperCodegen(CodeGen):
         self.writeline(
             f"{self.comment} [Provenance debug handles] {kernel_name}:{debug_handle}"
         )
+        if self.set_current_kernel_debug_handle:
+            self.current_kernel_debug_handle = (kernel_name, debug_handle)

     def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
         assert old.get_dtype() == new.get_dtype()

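The save/restore pattern behind `kernel_debug_handle_context` is small enough to show standalone. This runnable sketch (with a hypothetical `Wrapper` stand-in and an illustrative kernel name) shows how the flag gates writes from `write_provenance_debug_handle` and how both values are restored on exit:

```python
import contextlib

class Wrapper:  # hypothetical stand-in for PythonWrapperCodegen
    def __init__(self):
        self.current_kernel_debug_handle = None
        self.set_current_kernel_debug_handle = False

    @contextlib.contextmanager
    def kernel_debug_handle_context(self):
        # Save both values, enable writes, and restore on exit.
        old_flag = self.set_current_kernel_debug_handle
        old_handle = self.current_kernel_debug_handle
        self.set_current_kernel_debug_handle = True
        try:
            yield
        finally:
            self.set_current_kernel_debug_handle = old_flag
            self.current_kernel_debug_handle = old_handle

    def write_provenance_debug_handle(self, kernel_name, debug_handle):
        # Only record the handle while a context is active.
        if self.set_current_kernel_debug_handle:
            self.current_kernel_debug_handle = (kernel_name, debug_handle)

w = Wrapper()
with w.kernel_debug_handle_context():
    w.write_provenance_debug_handle("extern_kernels.mm", 3)  # hypothetical values
    assert w.current_kernel_debug_handle == ("extern_kernels.mm", 3)
assert w.current_kernel_debug_handle is None  # reset after the context
```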
View File

@@ -98,6 +98,8 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExpr
 from torch.fx.passes.fake_tensor_prop import FakeTensorProp
 from torch.monitor import _WaitCounter
 from torch.utils._ordered_set import OrderedSet
+from torch.fx.experimental import _config as fx_experimental_config
+
 from .._dynamo.backends.common import aot_autograd
 from .._dynamo.exc import ShortenTraceback, SkipFrame
@@ -1530,7 +1532,10 @@ class _InProcessFxCompile(FxCompile):
             # Dump provenance artifacts for debugging trace
             inductor_provenance_tracking_node_mappings = None
             inductor_kernel_stack_trace_str = None
-            if config.trace.provenance_tracking_level != 0:
+            if (
+                config.trace.provenance_tracking_level != 0
+                or fx_experimental_config.enrich_profiler_metadata
+            ):
                 inductor_provenance_tracking_node_mappings = json.dumps(
                     torch._inductor.debug.dump_inductor_provenance_info()
                 )

View File

@@ -1179,6 +1179,8 @@ torchinductor_worker_logpath: str = Config(
     default="",
 )

+fallback_by_default: bool = False
+
 # config specific to codegen/cpp.py
 class cpp:

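A usage sketch for the new flag, assuming the usual inductor config-patch mechanics; per the TODO in the test above, this flag is a stand-in for a future inductor lite mode that falls back to eager kernels by default:

```python
import torch

# Hypothetical usage: force eager fallbacks for all call_function nodes
# while the patch is active.
with torch._inductor.config.patch(fallback_by_default=True):
    fn = torch.compile(lambda x: x * 2 + 1, fullgraph=True)
    fn(torch.randn(4))
```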
View File

@@ -1106,7 +1106,7 @@ def set_kernel_post_grad_provenance_tracing(
     Returns a unique int debug handler for each call to this function.
     """
-    if config.trace.provenance_tracking_level == 0:
+    if config.trace.provenance_tracking_level == 0 and not config.fallback_by_default:
         return None

     try:

View File

@@ -1628,6 +1628,7 @@ class GraphLowering(torch.fx.Interpreter):
                     "inductor", "lowerings", lambda: repr(n)
                 )
             )
+            or (n.op == "call_function" and config.fallback_by_default)
         ):
             debug("fallback_handler")
             result = fallback_handler(n.target, add_to_fallback_set=False)(

View File

@@ -8079,27 +8079,28 @@ class FallbackKernel(ExternKernelAlloc):
             for v, a in zip(args_iter, kernel._schema.arguments)
         )

-        self.codegen_comment(wrapper)
-        if self.use_runtime_dispatch:
-            exported_args = self.export_extern_kernel_node()
-            assert self.python_kernel_name is not None
-            assert self.op_overload is not None
-
-            wrapper.generate_fallback_kernel_with_runtime_lookup(
-                self.get_name(),
-                self.python_kernel_name,
-                lambda: [*self.codegen_args(), *self.codegen_kwargs()],
-                self.op_overload,
-                exported_args,
-                # NOTE: [special handling of all_reduce_coalesced_'s return value]
-                self.outputs if self.outputs else self.mutation_outputs,
-            )
-        else:
-            wrapper.generate_fallback_kernel(self)
+        with wrapper.kernel_debug_handle_context():
+            self.codegen_comment(wrapper)
+            if self.use_runtime_dispatch:
+                exported_args = self.export_extern_kernel_node()
+                assert self.python_kernel_name is not None
+                assert self.op_overload is not None
+
+                wrapper.generate_fallback_kernel_with_runtime_lookup(
+                    self.get_name(),
+                    self.python_kernel_name,
+                    lambda: [*self.codegen_args(), *self.codegen_kwargs()],
+                    self.op_overload,
+                    exported_args,
+                    # NOTE: [special handling of all_reduce_coalesced_'s return value]
+                    self.outputs if self.outputs else self.mutation_outputs,
+                )
+            else:
+                wrapper.generate_fallback_kernel(self)

         if isinstance(self.layout, Layout):
             self.codegen_size_asserts(wrapper)
             self.codegen_alignment_asserts(wrapper)
         self.codegen_memory_tracking(wrapper)

         self.codegen_unbacked_symbol_defs(wrapper)

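What the wrapper emits around such a fallback call, reduced to a runnable sketch: a `_RecordFunctionFast` span named `## inductor_kernel:<kernel_name>:<debug_handle> ##`. Here `extern_kernels.mm:3` is a hypothetical kernel_name:debug_handle pair, and `torch.mm` stands in for the actual generated kernel invocation:

```python
import torch

# Shape of the generated wrapper code (see generate_record_function_start/_end).
_rf_enter = torch._C._profiler._RecordFunctionFast("## inductor_kernel:extern_kernels.mm:3 ##")
_rf_enter.__enter__()
out = torch.mm(torch.randn(2, 3), torch.randn(3, 2))  # stands in for the fallback kernel
_rf_enter.__exit__(None, None, None)
```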
View File

@@ -25,6 +25,8 @@ from __future__ import annotations

 import dataclasses
 import logging
 import os
+import random
+import string
 from functools import partial
 from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union

@@ -70,6 +72,10 @@ if TYPE_CHECKING:

 log = logging.getLogger(__name__)

+# Used by profiler post-processing to match events
+# belonging to the same compiled run
+CALL_COMPILED_PREFIX = "Call CompiledFxGraph"
+

 @dataclasses.dataclass
 class OutputCode:
@@ -612,9 +618,18 @@ class CompiledFxGraph(OutputCode):
         try:
             # Checking the profiler directly is faster than nullcontext
             if torch.autograd.profiler._is_profiler_enabled:
-                with record_function(
-                    f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##"
-                ):
+                # Generate a random string to represent this unique run if there is no cache key
+                run_key = (
+                    self._fx_graph_cache_key
+                    if self._fx_graph_cache_key
+                    else "".join(random.choices(string.ascii_lowercase, k=51))
+                )
+                run_name = f"{CALL_COMPILED_PREFIX} {run_key}"
+                if self.inductor_provenance_stack_traces_str:
+                    torch.fx.traceback._register_fx_metadata(
+                        run_name, self.inductor_provenance_stack_traces_str
+                    )
+                with record_function(f"## {run_name} ##"):
                     return self.current_callable(inputs)
             else:
                 return self.current_callable(inputs)

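The run-name convention above, restated as a small runnable sketch (`make_run_name` is a hypothetical helper; the PR inlines this logic in `CompiledFxGraph.__call__`):

```python
import random
import string

CALL_COMPILED_PREFIX = "Call CompiledFxGraph"

def make_run_name(fx_graph_cache_key: str | None) -> str:
    # Fall back to a random 51-char lowercase key when no cache key exists,
    # so each compiled run still gets a unique, matchable marker name.
    run_key = fx_graph_cache_key or "".join(
        random.choices(string.ascii_lowercase, k=51)
    )
    return f"{CALL_COMPILED_PREFIX} {run_key}"

# The profiler then sees record_function(f"## {make_run_name(key)} ##") spans.
print(make_run_name("abc123"))  # hypothetical cache key
```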
View File

@@ -1224,43 +1224,3 @@ def _build_table(
         f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
     )
     return "".join(result)
-
-
-# Collect all events with stack traces and format them canonically
-def _canonicalize_profiler_events(events):
-    """
-    Extract and format all events with stack traces in a canonical way
-    for deterministic testing.
-    """
-    events_with_traces = []
-
-    for event in events:
-        # Extract relevant fields
-        event_name = event.get("name", "")
-        node_name = event["args"].get("node_name", "")
-        stack_trace = event["args"].get("stack_trace", "")
-
-        # Get the last non-empty line of the stack trace
-        lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
-        stack_trace = lines[-1] if lines else ""
-
-        events_with_traces.append(
-            {
-                "event_name": event_name[:20],
-                "node_name": node_name,
-                "stack_trace": stack_trace,
-                "start_time": event.get("ts", 0),
-            }
-        )
-
-    # Sort by node_name for deterministic ordering
-    events_with_traces.sort(key=lambda x: x["start_time"])
-
-    # Format as a string
-    lines: list[str] = []
-    for evt in events_with_traces:
-        lines.append(
-            f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
-        )
-
-    return "\n".join(lines)

View File

@@ -43,10 +43,10 @@ should_preserve_node_meta = False
 # =============================================================================
 # Global in-memory registry for FX metadata
 # Maps module_name -> metadata dict containing lineno_map and node_metadata
-_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
+_FX_METADATA_REGISTRY: dict[str, str | dict[str, Any]] = {}


-def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
+def _register_fx_metadata(module_name: str, metadata: str | dict[str, Any]) -> None:
     """
     Register FX metadata in the global in-memory registry.

@@ -55,7 +55,7 @@ def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
     Args:
         module_name: The module identifier (content-addressed filename)
-        metadata: Metadata dict containing lineno_map, node_metadata, and source_code
+        metadata: Metadata dict containing lineno_map, node_metadata, and source_code, or a JSON string that loads to such a dict
     """
     # TODO: add logging to tlparse
     _FX_METADATA_REGISTRY[module_name] = metadata

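A sketch of the registry's widened contract: inductor registers the provenance payload as a JSON string under the compiled run name, and the profiler post-processing decodes it lazily on first use. The payload shape here is a guess for illustration only:

```python
import json

_FX_METADATA_REGISTRY: dict[str, str | dict] = {}

def _register(name: str, metadata: str | dict) -> None:
    _FX_METADATA_REGISTRY[name] = metadata

# Inductor side: store the raw JSON dump (hypothetical payload shape).
_register("Call CompiledFxGraph abc123", json.dumps({"kernel0:1": ["d = c * 2"]}))

# Profiler side: decode lazily, as map_recorded_events_to_aten_ops_with_stack_trace does.
meta = _FX_METADATA_REGISTRY["Call CompiledFxGraph abc123"]
if isinstance(meta, str):
    meta = json.loads(meta)
assert meta["kernel0:1"] == ["d = c * 2"]
```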
View File

@@ -1,9 +1,11 @@
 # mypy: allow-untyped-defs
 import functools
+import json
 import operator
 import re
 from collections import deque
 from dataclasses import dataclass
+from enum import Enum
 from typing import Any, Literal, Optional, TYPE_CHECKING

 from torch.autograd.profiler import profile
@@ -402,13 +404,31 @@ def _init_for_cuda_graphs() -> None:
         pass


+class ContextType(Enum):
+    """Types of contexts in the profiler stack."""
+
+    FX_GRAPH = "filename"
+    FX_NODE = "node"
+    COMPILED_GRAPH = "compiled_graph"
+    INDUCTOR_NODE = "inductor_node"
+
+
+def get_parent_context_type(context_type: ContextType) -> Optional[ContextType]:
+    if context_type == ContextType.FX_NODE:
+        return ContextType.FX_GRAPH
+    elif context_type == ContextType.INDUCTOR_NODE:
+        return ContextType.COMPILED_GRAPH
+    else:
+        return None
+
+
 @dataclass
 class TimelineEvent:
     """Represents an event in the profiler timeline."""

     timestamp: int
     event_type: Literal["start", "end", "regular"]
-    marker_type: Optional[Literal["filename", "node"]]
+    marker_type: Optional[ContextType]
     identifier: Optional[str | int]
     event: dict[str, Any]
@@ -417,7 +437,7 @@ class TimelineEvent:
 class ContextStackEntry:
     """Represents a context (filename or node) in the stack."""

-    context_type: Literal["filename", "node"]
+    context_type: ContextType
     identifier: str | int
     metadata: Optional[dict]
     tid: Optional[int] = None  # Thread ID associated with this context
@@ -438,6 +458,8 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
     Returns:
         Dict mapping recorded event names to their aten operations with added stack traces
     """
+    from torch._inductor.output_code import CALL_COMPILED_PREFIX
+    from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
     from torch.fx.traceback import _FX_METADATA_REGISTRY

     trace_events = traced_data.get("traceEvents", [])
@@ -447,7 +469,7 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
     def is_fx_marker_event(event):
         return (
-            event.get("cat") == "cpu_op"
+            event.get("cat") in ("cpu_op", "user_annotation")
             and event.get("name", "").startswith("## ")
             and event.get("name", "").endswith(" ##")
         )
@@ -469,14 +491,27 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
         if is_fx_marker_event(event):
             content = event["name"][3:-3]

-            if content.endswith(".py"):
-                append_fx_marker_event("filename", content, event)
+            # Try different event types
+            if content.startswith(FX_GRAPH_MODULE_FILE_PREFIX) and content.endswith(
+                ".py"
+            ):
+                # FX graph event
+                append_fx_marker_event(ContextType.FX_GRAPH, content, event)
+            elif content.startswith(CALL_COMPILED_PREFIX):
+                # Inductor compiled graph event
+                append_fx_marker_event(ContextType.COMPILED_GRAPH, content, event)
+            elif content.startswith("inductor_kernel:"):
+                append_fx_marker_event(
+                    ContextType.INDUCTOR_NODE, content[len("inductor_kernel:") :], event
+                )
             else:
+                # Try to parse the content as an FX node index
+                # TODO: change to start with fx_node
                 try:
                     node_index = int(content)
+                    append_fx_marker_event(ContextType.FX_NODE, node_index, event)
                 except ValueError:
                     pass
-                append_fx_marker_event("node", node_index, event)  # type: ignore[possibly-undefined]

         else:
             # Regular event that needs augmentation
@@ -495,23 +530,37 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
             case "start":
                 assert timeline_event.identifier is not None

-                if timeline_event.marker_type == "filename":
+                if timeline_event.marker_type in (
+                    ContextType.FX_GRAPH,
+                    ContextType.COMPILED_GRAPH,
+                ):
                     assert isinstance(timeline_event.identifier, str)

                     # Push filename context - query metadata registry on-demand
                     metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
                     tid = timeline_event.event.get("tid")
+                    # TODO: add a get method in traceback that wraps this lookup in try/except
+                    if isinstance(metadata, str):
+                        metadata = json.loads(metadata)
                     context_stack.append(
                         ContextStackEntry(
-                            "filename", timeline_event.identifier, metadata, tid
+                            timeline_event.marker_type,
+                            timeline_event.identifier,
+                            metadata,
+                            tid,
                         )
                     )
-                elif timeline_event.marker_type == "node":
+                elif timeline_event.marker_type in (
+                    ContextType.FX_NODE,
+                    ContextType.INDUCTOR_NODE,
+                ):
                     # Find the current filename from stack
                     current_file_metadata = None
                     tid = timeline_event.event.get("tid")
+                    parent_type = get_parent_context_type(timeline_event.marker_type)
                     for ctx_entry in reversed(context_stack):
                         if (
-                            ctx_entry.context_type == "filename"
+                            ctx_entry.context_type == parent_type
                             and ctx_entry.tid == tid
                         ):
                             current_file_metadata = ctx_entry.metadata
@@ -520,14 +569,39 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
                     if current_file_metadata:
                         node_metadata = current_file_metadata.get("node_metadata", {})
                         if timeline_event.identifier in node_metadata:
-                            node_meta: Optional[dict] = node_metadata[
-                                timeline_event.identifier
-                            ]
-                            context_stack.append(
-                                ContextStackEntry(
-                                    "node", timeline_event.identifier, node_meta, tid
-                                )
-                            )
+                            if ctx_entry.context_type == ContextType.FX_NODE:
+                                node_meta: Optional[dict] = node_metadata[
+                                    timeline_event.identifier
+                                ]
+                                context_stack.append(
+                                    ContextStackEntry(
+                                        ContextType.FX_NODE,
+                                        timeline_event.identifier,
+                                        node_meta,
+                                        tid,
+                                    )
+                                )
+                        if timeline_event.marker_type == ContextType.INDUCTOR_NODE:
+                            # Look up stack traces for this kernel
+                            # TODO: make a dictionary that maps from compiled key to stack traces dictionary
+                            stack_traces = current_file_metadata.get(
+                                timeline_event.identifier, []
+                            )
+                            if stack_traces:
+                                # Store all stack traces as metadata
+                                node_meta: Optional[dict] = {
+                                    "stack_trace": stack_traces,
+                                    "name": timeline_event.identifier,
+                                }
+                                context_stack.append(
+                                    ContextStackEntry(
+                                        ContextType.INDUCTOR_NODE,
+                                        timeline_event.identifier,
+                                        node_meta,
+                                        tid,
+                                    )
+                                )

             case "end":
                 # Pop from stack - search backwards to find matching context
@@ -551,7 +625,10 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
                 for ctx_entry in reversed(context_stack):
                     # Only apply metadata from contexts with matching tid
                     if ctx_entry.tid == event_tid:
-                        if ctx_entry.context_type == "node" and ctx_entry.metadata:
+                        if (
+                            ctx_entry.context_type == ContextType.FX_NODE
+                            and ctx_entry.metadata
+                        ):
                             current_stack_trace = ctx_entry.metadata.get(
                                 "stack_trace", "No model stack trace available"
                             )
@@ -559,6 +636,19 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
                             # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes
                             # if nodes are nested, e.g. in nested graph modules
                             break
+                        elif (
+                            ctx_entry.context_type == ContextType.INDUCTOR_NODE
+                            and ctx_entry.metadata
+                        ):
+                            # For inductor nodes, stack_trace is a list of traces
+                            stack_traces_list = ctx_entry.metadata.get(
+                                "stack_trace", []
+                            )
+                            if stack_traces_list:
+                                # Store as a list - each trace gets its own entry
+                                current_stack_trace = stack_traces_list
+                                current_node_name = ctx_entry.metadata.get("name", "")
+                            break

                 # Augment the event
                 if current_stack_trace or current_node_name:
@@ -567,3 +657,81 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
                         args["stack_trace"] = current_stack_trace
                     if current_node_name:
                         args["node_name"] = current_node_name
+
+
+import tempfile
+import os
+
+
+# Collect all events with stack traces and format them canonically
+def _canonicalize_profiler_events(events):
+    """
+    Extract and format all events with stack traces in a canonical way
+    for deterministic testing.
+    """
+    events_with_traces = []
+
+    for event in events:
+        # Extract relevant fields
+        event_name = event.get("name", "")
+        node_name = event["args"].get("node_name", "")
+        stack_trace = event["args"].get("stack_trace", "")
+        if isinstance(stack_trace, list):
+            stack_trace = "\n".join(stack_trace)
+
+        # Get the last non-empty line of the stack trace
+        lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
+        stack_trace = lines[-1] if lines else ""
+
+        events_with_traces.append(
+            {
+                "event_name": event_name[:20],
+                "node_name": node_name,
+                "stack_trace": stack_trace,
+                "start_time": event.get("ts", 0),
+            }
+        )
+
+    # Sort by start time for deterministic ordering
+    events_with_traces.sort(key=lambda x: x["start_time"])
+
+    # Format as a string
+    lines: list[str] = []
+    for evt in events_with_traces:
+        lines.append(
+            f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
+        )
+
+    return "\n".join(lines)
+
+
+def _enrich_profiler_traces(prof):
+    """
+    Helper function to extract and augment profiler events with stack traces.
+
+    Args:
+        prof: A torch.profiler.profile object
+
+    Returns:
+        A string representing the enriched events
+    """
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+        trace_file = f.name
+    try:
+        prof.export_chrome_trace(trace_file)
+
+        with open(trace_file) as f:
+            trace_data = json.load(f)
+
+        map_recorded_events_to_aten_ops_with_stack_trace(trace_data)
+
+        events = []
+        for event in trace_data["traceEvents"]:
+            if "args" in event and "stack_trace" in event["args"]:
+                events.append(event)
+
+        actual_traces = _canonicalize_profiler_events(events)
+        return actual_traces
+    finally:
+        if os.path.exists(trace_file):
+            os.remove(trace_file)
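The marker classification implemented above, condensed into a runnable sketch. `classify_marker` is a hypothetical helper rather than part of the PR, and the value of `FX_GRAPH_MODULE_FILE_PREFIX` is an assumption mirroring `torch.fx.graph_module`:

```python
FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_"  # assumed value
CALL_COMPILED_PREFIX = "Call CompiledFxGraph"

def classify_marker(name: str) -> tuple[str, str | int] | None:
    # Markers are profiler event names of the form "## <content> ##".
    if not (name.startswith("## ") and name.endswith(" ##")):
        return None
    content = name[3:-3]
    if content.startswith(FX_GRAPH_MODULE_FILE_PREFIX) and content.endswith(".py"):
        return ("FX_GRAPH", content)          # generated FX module file
    if content.startswith(CALL_COMPILED_PREFIX):
        return ("COMPILED_GRAPH", content)    # one run of a compiled graph
    if content.startswith("inductor_kernel:"):
        return ("INDUCTOR_NODE", content[len("inductor_kernel:"):])
    try:
        return ("FX_NODE", int(content))      # bare integer = FX node index
    except ValueError:
        return None

assert classify_marker("## inductor_kernel:extern_kernels.mm:3 ##") == ("INDUCTOR_NODE", "extern_kernels.mm:3")
assert classify_marker("## 7 ##") == ("FX_NODE", 7)
```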