Compare commits


5 Commits

17 changed files with 373 additions and 138 deletions

View File

@ -74,13 +74,13 @@ class TestStreams(torch._dynamo.test_case.TestCase):
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
# Annotation: {'stream': 0}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None
add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None
return (add_3,)
@ -196,6 +196,7 @@ class <lambda>(torch.nn.Module):
s_exp = fn(*inp)
self.assertEqual(s_act, s_exp)
@requires_cuda
def test_nested_stream_enter_exit(self):
def fn(x, y, s0, s1, s2):
with s1:
@ -229,13 +230,13 @@ class <lambda>(torch.nn.Module):
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"):
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1)
# Annotation: {'stream': None}
# Annotation: {'stream': 2}
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
# Annotation: {'stream': None}
# Annotation: {'stream': 1}
add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None
return (add_1, add_2)
""",
@ -249,6 +250,7 @@ class <lambda>(torch.nn.Module):
def test_nested_stream_enter_exit_graph_break(self):
pass
@requires_cuda
def test_local_stream_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
@ -289,6 +291,7 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
def test_local_stream_nested_enter_exit(self):
def fn(x, y):
s2 = torch.Stream()
@ -331,6 +334,7 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
def test_stream_with_mutation(self):
def fn(x, y):
s2 = torch.Stream()
@ -380,6 +384,7 @@ class <lambda>(torch.nn.Module):
""",
)
@requires_cuda
def test_stream_backward(self) -> None:
def fn(x, y):
s2 = torch.Stream()

View File

@ -27,9 +27,9 @@ from torch._inductor.fx_passes.post_grad import post_grad_passes
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code, run_and_get_cpp_code
from torch._inductor.virtualized import V
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.common_utils import IS_MACOS, skipIfRocm
from torch.testing._internal.triton_utils import requires_cuda_and_triton
from torch.profiler import profile, ProfilerActivity
try:
from .test_aot_inductor_utils import AOTIRunnerUtil
@ -941,5 +941,63 @@ copy_tests(
)
from torch.profiler._utils import _enrich_profiler_traces
class TestProfilerStackTraceAugmentation(TestCase):
"""
Test that profiler events are correctly augmented with stack traces
from both FX metadata and inductor kernel stack traces.
"""
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@skipIfRocm
@torch.fx.experimental._config.patch("enrich_profiler_metadata", True)
@config.patch("fallback_by_default", True) # TODO update the config patch to inductor lite mode
@torch.compiler.config.patch("force_disable_caches", True)
def test_profiler_inductor_stack_trace_augmentation(self):
"""
Test that map_recorded_events_to_aten_ops_with_stack_trace correctly
augments profiler events with stack traces from inductor kernel metadata.
"""
# Test model similar to test.py
class TestModel(torch.nn.Module):
def forward(self, c):
d = c * 2
d = d + 1
return d
device = "cuda"
model = TestModel().to(device)
c = torch.randn((64, 32), device=device)
# Force disable caches to ensure fresh compilation
torch.compiler.config.force_disable_caches = True
# Compile the model
compiled_model = torch.compile(model, fullgraph=True)
# Warmup
for _ in range(3):
_ = compiled_model(c)
# Profile with the compiled model
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
) as prof:
compiled_model(c)
actual_traces = _enrich_profiler_traces(prof)
self.assertExpectedInline(actual_traces, """\
event=aten::mul node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
event=cudaLaunchKernel node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
event=aten::add node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1
event=cudaLaunchKernel node=torch.ops.aten.add.Tensor:2 stack_trace=d = d + 1""")
# TODO: add a test that no RecordFunctionFast markers are generated when enrich_profiler_metadata is not enabled.
if __name__ == "__main__":
run_tests()

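Editorial note, not part of the diff: a minimal standalone sketch of the flow this test exercises, assuming a CUDA build and the _enrich_profiler_traces helper added to torch/profiler/_utils.py in this change (cache handling and the fallback_by_default patch are omitted for brevity):

import torch
from torch.fx.experimental import _config as fx_config
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import _enrich_profiler_traces  # added in this change

def fn(c):
    d = c * 2
    return d + 1

with fx_config.patch("enrich_profiler_metadata", True):
    compiled = torch.compile(fn, fullgraph=True)
    c = torch.randn(64, 32, device="cuda")
    compiled(c)  # warm up / compile
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        compiled(c)
    # prints lines like: event=aten::mul node=torch.ops.aten.mul.Tensor:1 stack_trace=d = c * 2
    print(_enrich_profiler_traces(prof))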
View File

@ -76,11 +76,8 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.jit_utils import JitTestCase
import json
import tempfile
from torch.profiler import profile, ProfilerActivity
from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace
from torch.autograd.profiler_util import _canonicalize_profiler_events
from torch.profiler._utils import _enrich_profiler_traces
try:
from torchvision import models as torchvision_models
@ -208,36 +205,6 @@ def side_effect_func(x: torch.Tensor):
print(x)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f:
trace_file = f.name
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(
trace_data
)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
class TestFX(JitTestCase):
def setUp(self):
super().setUp()

View File

@ -1542,7 +1542,7 @@ class OutputGraph(OutputGraphCommon):
)
)
tmp_vars = []
for constructor in reversed(index_to_bytecode_constructor.values()):
for constructor in index_to_bytecode_constructor.values():
constructor(codegen)
var_name = (
self.new_var()

View File

@ -1061,9 +1061,7 @@ class VariableBuilder:
)
set_example_value(stream_proxy.node, value)
var = StreamVariable(
stream_proxy,
value,
source=self.source,
stream_proxy, value, source=self.source, user_object_index=index
)
return self.tx.output.side_effects.track_object_existing(value, var)
elif isinstance(value, (torch._C._SDPAParams)):
@ -3006,14 +3004,16 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
return SymNodeVariable(proxy, example_value, **options)
elif (
isinstance(example_value, torch.Stream)
and proxy.node.target
in (get_external_object_by_index, torch.accelerator.current_stream)
and proxy.node.target == get_external_object_by_index
) or proxy.node.target in [
device_interface.current_stream
for _, device_interface in get_registered_device_interfaces()
]:
set_example_value(proxy.node, example_value)
return StreamVariable(proxy, example_value, **options)
index = None
if proxy.node.target == get_external_object_by_index:
index = proxy.node.args[0]
return StreamVariable(proxy, example_value, index, **options)
elif (
inspect.isclass(proxy.node.target)
and issubclass(proxy.node.target, torch.Event)

View File

@ -204,11 +204,11 @@ class StreamVariable(StreamContextVariable):
self,
proxy: Proxy,
value: torch.Stream,
user_object_index: Optional[int] = None,
**kwargs: Any,
) -> None:
# Index into the user object table
# used to pass arbitrary objects to the graph
user_object_index = kwargs.pop("user_obj_index", None)
if proxy is not None and "example_value" in proxy.node.meta:
assert proxy.node.meta["example_value"] == value
@ -300,7 +300,7 @@ class StreamVariable(StreamContextVariable):
codegen.append_output(codegen.create_load_const(self.user_object_index))
codegen.extend_output(create_call_function(1, False))
else:
# TODO mlazos: evaluate if we still need this
# This will support the legacy behavior
prefix = f"_stream_{self.device}"
name = codegen.tx.output.install_global_by_id(prefix, self.value)
codegen.append_output(codegen.create_load_global(name, add=True))

View File

@ -838,7 +838,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
proxy=tx.output.create_proxy(
"call_function", get_external_object_by_index, (ind,), {}
),
user_obj_index=ind,
)
else:
tensor_variable = wrap_fx_proxy(

View File
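Editorial note, not part of the diff: user_object_index records where Dynamo registered the user-created stream, so the graph can reload it through get_external_object_by_index and the stream annotations in the tests above carry a concrete index instead of None. A rough sketch of the traced pattern (the index value is illustrative):

import torch

def fn(x):
    s = torch.Stream()   # registered in the user-object table, say at index 0
    with s:
        # traced as: stream = get_external_object_by_index(0)
        #            add annotated with {'stream': 0}
        return x + 1

compiled = torch.compile(fn, fullgraph=True)  # stream ops need an accelerator backend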

@ -70,7 +70,7 @@ from .common import (
)
from .cpp_utils import cexpr
from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta
from torch.fx.experimental import _config as fx_experimental_config
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
@ -1120,6 +1120,37 @@ class PythonWrapperCodegen(CodeGen):
# Additional files that are dependent to the wrapper (ex. cubin files)
self.additional_files = []
# This is used to emit RecordFunctionFast markers that can be matched
# with profiler traces for provenance tracking.
#
# Stores the (kernel_name, debug_handle) tuple
# for the kernel currently being generated.
self.current_kernel_debug_handle: Optional[tuple[str, int]] = None
# set_current_kernel_debug_handle: Flag that controls whether
# write_provenance_debug_handle() should update current_kernel_debug_handle.
# This flag is automatically managed by kernel_debug_handle_context().
self.set_current_kernel_debug_handle: bool = False
@contextlib.contextmanager
def kernel_debug_handle_context(self):
"""
Context manager for kernel debug handle tracking.
Within the context, self.current_kernel_debug_handle can be updated
via wrapper.write_provenance_debug_handle; both the handle and the
tracking flag are restored to their previous values on exit.
"""
old_flag_value = self.set_current_kernel_debug_handle
old_handle_value = self.current_kernel_debug_handle
self.set_current_kernel_debug_handle = True
try:
yield
finally:
self.set_current_kernel_debug_handle = old_flag_value
self.current_kernel_debug_handle = old_handle_value
@staticmethod
def create(
is_subgraph: bool,
@ -1510,8 +1541,27 @@ class PythonWrapperCodegen(CodeGen):
def generate_end(self, result: IndentedBuffer) -> None:
return
def generate_record_function_start(self) -> Optional[str]:
record_func = self.current_kernel_debug_handle and fx_experimental_config.enrich_profiler_metadata
if record_func:
assert self.current_kernel_debug_handle
kernel_name, debug_handle = self.current_kernel_debug_handle
kernel_debug_handle = f"{kernel_name}:{debug_handle}"
self.writeline(
f"_rf_enter = torch._C._profiler._RecordFunctionFast('## inductor_kernel:{kernel_debug_handle} ##'); _rf_enter.__enter__()"
)
return "_rf_enter"
else:
return None
def generate_record_function_end(self, record_func_var: Optional[str]):
if record_func_var:
self.writeline(f"{record_func_var}.__exit__(None, None, None)")
def generate_fallback_kernel(self, node: ir.FallbackKernel) -> None:
record_func_var = self.generate_record_function_start()
self.writeline(ExternKernelAllocLine(self, node))
self.generate_record_function_end(record_func_var)
def generate_extern_kernel_alloc(self, node: ir.ExternKernelAlloc):
node.codegen_comment(self)
@ -1671,7 +1721,9 @@ class PythonWrapperCodegen(CodeGen):
raw_args: Sequence[Any],
outputs: Sequence[ir.Buffer],
) -> None:
record_func_var = self.generate_record_function_start()
self.writeline(f"{buf_name} = {python_kernel_name}({', '.join(get_args())})")
self.generate_record_function_end(record_func_var)
def generate(self, is_inference):
with dynamo_timed("PythonWrapperCodegen.generate"):
@ -3142,6 +3194,8 @@ class PythonWrapperCodegen(CodeGen):
self.writeline(
f"{self.comment} [Provenance debug handles] {kernel_name}:{debug_handle}"
)
if self.set_current_kernel_debug_handle:
self.current_kernel_debug_handle = (kernel_name, debug_handle)
def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool):
assert old.get_dtype() == new.get_dtype()

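Editorial illustration, not part of the diff: with enrich_profiler_metadata set, the wrapper code emitted around a fallback kernel looks roughly like the following (kernel name, buffer names, and debug handle are invented):

# [Provenance debug handles] aten.convolution.default:3
_rf_enter = torch._C._profiler._RecordFunctionFast('## inductor_kernel:aten.convolution.default:3 ##'); _rf_enter.__enter__()
buf2 = torch.ops.aten.convolution.default(buf0, arg1_1, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
_rf_enter.__exit__(None, None, None)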
View File

@ -98,6 +98,8 @@ from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExpr
from torch.fx.passes.fake_tensor_prop import FakeTensorProp
from torch.monitor import _WaitCounter
from torch.utils._ordered_set import OrderedSet
from torch.fx.experimental import _config as fx_experimental_config
from .._dynamo.backends.common import aot_autograd
from .._dynamo.exc import ShortenTraceback, SkipFrame
@ -1530,7 +1532,10 @@ class _InProcessFxCompile(FxCompile):
# Dump provenance artifacts for debugging trace
inductor_provenance_tracking_node_mappings = None
inductor_kernel_stack_trace_str = None
if config.trace.provenance_tracking_level != 0:
if (
config.trace.provenance_tracking_level != 0
or fx_experimental_config.enrich_profiler_metadata
):
inductor_provenance_tracking_node_mappings = json.dumps(
torch._inductor.debug.dump_inductor_provenance_info()
)

View File

@ -1179,6 +1179,8 @@ torchinductor_worker_logpath: str = Config(
default="",
)
fallback_by_default: bool = False
# config specific to codegen/cpp.py
class cpp:

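Editorial sketch, not part of the diff: flipping the new flag programmatically, mirroring the config.patch in the test above. With it set, call_function nodes are lowered via fallback_handler (see the graph.py hunk below) and provenance debug handles stay enabled for them (debug.py hunk below).

import torch._inductor.config as inductor_config

inductor_config.fallback_by_default = True  # lite mode: fall back on every call_function node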
View File

@ -1106,7 +1106,7 @@ def set_kernel_post_grad_provenance_tracing(
Returns a unique int debug handler for each call to this function.
"""
if config.trace.provenance_tracking_level == 0:
if config.trace.provenance_tracking_level == 0 and not config.fallback_by_default:
return None
try:

View File

@ -1628,6 +1628,7 @@ class GraphLowering(torch.fx.Interpreter):
"inductor", "lowerings", lambda: repr(n)
)
)
or (n.op == "call_function" and config.fallback_by_default)
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(

View File

@ -8079,27 +8079,28 @@ class FallbackKernel(ExternKernelAlloc):
for v, a in zip(args_iter, kernel._schema.arguments)
)
self.codegen_comment(wrapper)
if self.use_runtime_dispatch:
exported_args = self.export_extern_kernel_node()
assert self.python_kernel_name is not None
assert self.op_overload is not None
with wrapper.kernel_debug_handle_context():
self.codegen_comment(wrapper)
if self.use_runtime_dispatch:
exported_args = self.export_extern_kernel_node()
assert self.python_kernel_name is not None
assert self.op_overload is not None
wrapper.generate_fallback_kernel_with_runtime_lookup(
self.get_name(),
self.python_kernel_name,
lambda: [*self.codegen_args(), *self.codegen_kwargs()],
self.op_overload,
exported_args,
# NOTE: [special handling of all_reduce_coalesced_'s return value]
self.outputs if self.outputs else self.mutation_outputs,
)
else:
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)
wrapper.generate_fallback_kernel_with_runtime_lookup(
self.get_name(),
self.python_kernel_name,
lambda: [*self.codegen_args(), *self.codegen_kwargs()],
self.op_overload,
exported_args,
# NOTE: [special handling of all_reduce_coalesced_'s return value]
self.outputs if self.outputs else self.mutation_outputs,
)
else:
wrapper.generate_fallback_kernel(self)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
self.codegen_alignment_asserts(wrapper)
self.codegen_memory_tracking(wrapper)
self.codegen_unbacked_symbol_defs(wrapper)

View File

@ -25,6 +25,8 @@ from __future__ import annotations
import dataclasses
import logging
import os
import random
import string
from functools import partial
from typing import Any, Optional, TYPE_CHECKING, TypeAlias, Union
@ -70,6 +72,10 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
# Prefix used by profiler post-processing to match
# record_function markers from the same compiled run
CALL_COMPILED_PREFIX = "Call CompiledFxGraph"
@dataclasses.dataclass
class OutputCode:
@ -612,9 +618,18 @@ class CompiledFxGraph(OutputCode):
try:
# Checking the profiler directly is faster than nullcontext
if torch.autograd.profiler._is_profiler_enabled:
with record_function(
f"## Call CompiledFxGraph {self._fx_graph_cache_key} ##"
):
# generate a random string to identify this run when there is no cache key
run_key = (
self._fx_graph_cache_key
if self._fx_graph_cache_key
else "".join(random.choices(string.ascii_lowercase, k=51))
)
run_name = f"{CALL_COMPILED_PREFIX} {run_key}"
if self.inductor_provenance_stack_traces_str:
torch.fx.traceback._register_fx_metadata(
run_name, self.inductor_provenance_stack_traces_str
)
with record_function(f"## {run_name} ##"):
return self.current_callable(inputs)
else:
return self.current_callable(inputs)

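Editorial note, not part of the diff: the record_function marker around each compiled-graph call and the metadata registered under the same name are what the post-processing in torch/profiler/_utils.py later joins on. A sketch with an invented cache key and payload:

import json
import torch.fx.traceback as fx_traceback

run_name = "Call CompiledFxGraph fxg1a2b3c"  # CALL_COMPILED_PREFIX + cache key (or the random fallback key)
payload = json.dumps({"triton_poi_fused_add_mul_0:1": ["d = c * 2"]})  # "kernel:handle" -> stack traces, invented
fx_traceback._register_fx_metadata(run_name, payload)  # stored as a JSON string; json.loads'd by the consumer
# at run time the call itself executes under record_function(f"## {run_name} ##")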
View File

@ -1224,43 +1224,3 @@ def _build_table(
f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}"
)
return "".join(result)
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)

View File

@ -43,10 +43,10 @@ should_preserve_node_meta = False
# =============================================================================
# Global in-memory registry for FX metadata
# Maps module_name -> metadata dict containing lineno_map and node_metadata
_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {}
_FX_METADATA_REGISTRY: dict[str, str | dict[str, Any]] = {}
def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
def _register_fx_metadata(module_name: str, metadata: str | dict[str, Any]) -> None:
"""
Register FX metadata in the global in-memory registry.
@ -55,7 +55,7 @@ def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None:
Args:
module_name: The module identifier (content-addressed filename)
metadata: Metadata dict containing lineno_map, node_metadata, and source_code
metadata: Metadata dict containing lineno_map, node_metadata, and source_code, or a JSON string that json.loads into such a dict.
"""
# TODO: add logging to tlparse
_FX_METADATA_REGISTRY[module_name] = metadata

View File

@ -1,9 +1,11 @@
# mypy: allow-untyped-defs
import functools
import json
import operator
import re
from collections import deque
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional, TYPE_CHECKING
from torch.autograd.profiler import profile
@ -402,13 +404,31 @@ def _init_for_cuda_graphs() -> None:
pass
class ContextType(Enum):
"""Types of contexts in the profiler stack."""
FX_GRAPH = "filename"
FX_NODE = "node"
COMPILED_GRAPH = "compiled_graph"
INDUCTOR_NODE = "inductor_node"
def get_parent_context_type(context_type: ContextType) -> Optional[ContextType]:
if context_type == ContextType.FX_NODE:
return ContextType.FX_GRAPH
elif context_type == ContextType.INDUCTOR_NODE:
return ContextType.COMPILED_GRAPH
else:
return None
@dataclass
class TimelineEvent:
"""Represents an event in the profiler timeline."""
timestamp: int
event_type: Literal["start", "end", "regular"]
marker_type: Optional[Literal["filename", "node"]]
marker_type: Optional[ContextType]
identifier: Optional[str | int]
event: dict[str, Any]
@ -417,7 +437,7 @@ class TimelineEvent:
class ContextStackEntry:
"""Represents a context (filename or node) in the stack."""
context_type: Literal["filename", "node"]
context_type: ContextType
identifier: str | int
metadata: Optional[dict]
tid: Optional[int] = None # Thread ID associated with this context
@ -438,6 +458,8 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
from torch._inductor.output_code import CALL_COMPILED_PREFIX
from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX
from torch.fx.traceback import _FX_METADATA_REGISTRY
trace_events = traced_data.get("traceEvents", [])
@ -447,7 +469,7 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
def is_fx_marker_event(event):
return (
event.get("cat") == "cpu_op"
event.get("cat") in ("cpu_op", "user_annotation")
and event.get("name", "").startswith("## ")
and event.get("name", "").endswith(" ##")
)
@ -469,14 +491,27 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
if is_fx_marker_event(event):
content = event["name"][3:-3]
if content.endswith(".py"):
append_fx_marker_event("filename", content, event)
# Try different event types
if content.startswith(FX_GRAPH_MODULE_FILE_PREFIX) and content.endswith(
".py"
):
# FX graph event
append_fx_marker_event(ContextType.FX_GRAPH, content, event)
elif content.startswith(CALL_COMPILED_PREFIX):
# Inductor compiled graph event
append_fx_marker_event(ContextType.COMPILED_GRAPH, content, event)
elif content.startswith("inductor_kernel:"):
append_fx_marker_event(
ContextType.INDUCTOR_NODE, content[len("inductor_kernel:") :], event
)
else:
# Try to parse as node index for FX graph
# TODO: change to start with fx_node
try:
node_index = int(content)
append_fx_marker_event(ContextType.FX_NODE, node_index, event)
except ValueError:
pass
append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined]
else:
# Regular event that needs augmentation
@ -495,23 +530,37 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
case "start":
assert timeline_event.identifier is not None
if timeline_event.marker_type == "filename":
if timeline_event.marker_type in (
ContextType.FX_GRAPH,
ContextType.COMPILED_GRAPH,
):
assert isinstance(timeline_event.identifier, str)
# Push filename context - query metadata registry on-demand
metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier)
tid = timeline_event.event.get("tid")
# TODO: add a get method in traceback that wraps this lookup in try/except
if isinstance(metadata, str):
metadata = json.loads(metadata)
context_stack.append(
ContextStackEntry(
"filename", timeline_event.identifier, metadata, tid
timeline_event.marker_type,
timeline_event.identifier,
metadata,
tid,
)
)
elif timeline_event.marker_type == "node":
elif timeline_event.marker_type in (
ContextType.FX_NODE,
ContextType.INDUCTOR_NODE,
):
# Find the current filename from stack
current_file_metadata = None
tid = timeline_event.event.get("tid")
parent_type = get_parent_context_type(timeline_event.marker_type)
for ctx_entry in reversed(context_stack):
if (
ctx_entry.context_type == "filename"
ctx_entry.context_type == parent_type
and ctx_entry.tid == tid
):
current_file_metadata = ctx_entry.metadata
@ -520,14 +569,39 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
if current_file_metadata:
node_metadata = current_file_metadata.get("node_metadata", {})
if timeline_event.identifier in node_metadata:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
"node", timeline_event.identifier, node_meta, tid
if ctx_entry.context_type == ContextType.FX_NODE:
node_meta: Optional[dict] = node_metadata[
timeline_event.identifier
]
context_stack.append(
ContextStackEntry(
ContextType.FX_NODE,
timeline_event.identifier,
node_meta,
tid,
)
)
if timeline_event.marker_type == ContextType.INDUCTOR_NODE:
# Look up stack traces for this kernel
# TODO: make a dictionary that maps from compiled key to stack traces dictionary
stack_traces = current_file_metadata.get(
timeline_event.identifier, []
)
if stack_traces:
# Store all stack traces as metadata
node_meta: Optional[dict] = {
"stack_trace": stack_traces,
"name": timeline_event.identifier,
}
context_stack.append(
ContextStackEntry(
ContextType.INDUCTOR_NODE,
timeline_event.identifier,
node_meta,
tid,
)
)
case "end":
# Pop from stack - search backwards to find matching context
@ -551,7 +625,10 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
for ctx_entry in reversed(context_stack):
# Only apply metadata from contexts with matching tid
if ctx_entry.tid == event_tid:
if ctx_entry.context_type == "node" and ctx_entry.metadata:
if (
ctx_entry.context_type == ContextType.FX_NODE
and ctx_entry.metadata
):
current_stack_trace = ctx_entry.metadata.get(
"stack_trace", "No model stack trace available"
)
@ -559,6 +636,19 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
# Do we want to attach only the stack trace of the innermost node, or the stack
# traces of all nodes, when nodes are nested (e.g. in nested graph modules)?
break
elif (
ctx_entry.context_type == ContextType.INDUCTOR_NODE
and ctx_entry.metadata
):
# For inductor nodes, stack_trace is a list of traces
stack_traces_list = ctx_entry.metadata.get(
"stack_trace", []
)
if stack_traces_list:
# Store as a list - each trace gets its own entry
current_stack_trace = stack_traces_list
current_node_name = ctx_entry.metadata.get("name", "")
break
# Augment the event
if current_stack_trace or current_node_name:
@ -567,3 +657,81 @@ def map_recorded_events_to_aten_ops_with_stack_trace(traced_data):
args["stack_trace"] = current_stack_trace
if current_node_name:
args["node_name"] = current_node_name
import tempfile
import os
# Collect all events with stack traces and format them canonically
def _canonicalize_profiler_events(events):
"""
Extract and format all events with stack traces in a canonical way
for deterministic testing.
"""
events_with_traces = []
for event in events:
# Extract relevant fields
event_name = event.get("name", "")
node_name = event["args"].get("node_name", "")
stack_trace = event["args"].get("stack_trace", "")
if isinstance(stack_trace, list):
stack_trace = "\n".join(stack_trace)
# Get the last non-empty line of the stack trace
lines = [s.strip() for s in stack_trace.split("\n") if s.strip()]
stack_trace = lines[-1] if lines else ""
events_with_traces.append(
{
"event_name": event_name[:20],
"node_name": node_name,
"stack_trace": stack_trace,
"start_time": event.get("ts", 0),
}
)
# Sort by node_name for deterministic ordering
events_with_traces.sort(key=lambda x: x["start_time"])
# Format as a string
lines: list[str] = []
for evt in events_with_traces:
lines.append(
f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}"
)
return "\n".join(lines)
def _enrich_profiler_traces(prof):
"""
Helper function to extract and augment profiler events with stack traces.
Args:
prof: A torch.profiler.profile object
Returns:
A string representing enriched events
"""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
trace_file = f.name
try:
prof.export_chrome_trace(trace_file)
with open(trace_file) as f:
trace_data = json.load(f)
map_recorded_events_to_aten_ops_with_stack_trace(trace_data)
events = []
for event in trace_data["traceEvents"]:
if "args" in event and "stack_trace" in event["args"]:
events.append(event)
actual_traces = _canonicalize_profiler_events(events)
return actual_traces
finally:
if os.path.exists(trace_file):
os.remove(trace_file)
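Editorial summary, not part of the diff: the four profiler marker shapes the updated post-processing recognizes, and the context type each one opens (the identifiers below are placeholders):

MARKER_KINDS = {
    "## <FX_GRAPH_MODULE_FILE_PREFIX><hash>.py ##": "ContextType.FX_GRAPH",   # FX graph metadata module
    "## 3 ##": "ContextType.FX_NODE",                                         # node index within the enclosing FX graph
    "## Call CompiledFxGraph <cache key> ##": "ContextType.COMPILED_GRAPH",
    "## inductor_kernel:<kernel>:<handle> ##": "ContextType.INDUCTOR_NODE",   # keys into the compiled graph's stack-trace dict
}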