Compare commits

...

1 Commits

Author SHA1 Message Date
23c5c1df96 codegen trace 2025-10-30 16:14:36 -07:00
3 changed files with 87 additions and 0 deletions

View File

@ -65,3 +65,76 @@ def get_node_context(node, num_nodes=2) -> str:
break
cur = cur.prev
return "\n".join(node_contexts[::-1])
def map_recorded_events_to_aten_ops_with_stack_trace(graph_module, traced_data):
"""
Maps recorded profiler events to their corresponding aten operations and adds stack traces.
Args:
graph_module: The FX GraphModule
traced_data: Json of profiler events from Chrome trace
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
trace_events = traced_data.get("traceEvents", [])
# Create a mapping from node name to node for easy lookup
node_map = {node.name: node for node in graph_module.graph.nodes}
# Find aten operation events
aten_events = [e for e in trace_events if e.get("cat") == "cpu_op"]
# Map recorded events to aten ops and add stack traces
event_mapping = {}
for recorded_event in trace_events:
if (recorded_event.get("cat") in ["cpu_op"] and
recorded_event.get("name", "").startswith("## ") and
recorded_event.get("name", "").endswith(" ##")):
# Extract node name from "## node_name ##"
node_name = recorded_event["name"][3:-3] # Remove "## " and " ##"
if node_name in node_map:
node = node_map[node_name]
# Find corresponding aten operations within this recorded event's time window
recorded_start = recorded_event["ts"]
recorded_end = recorded_start + recorded_event["dur"]
# Find aten ops that fall within this time window
corresponding_aten_ops = []
for aten_event in aten_events:
aten_start = aten_event["ts"]
aten_end = aten_start + aten_event["dur"]
# Check if aten event overlaps with recorded event
if (aten_start >= recorded_start and aten_start <= recorded_end) or \
(aten_end >= recorded_start and aten_end <= recorded_end) or \
(aten_start <= recorded_start and aten_end >= recorded_end):
corresponding_aten_ops.append(aten_event)
# Add stack trace to recorded event and aten ops
stack_trace = node.meta.get("stack_trace", "No stack trace available")
# Add stack trace to the recorded event
if "args" not in recorded_event:
recorded_event["args"] = {}
recorded_event["args"]["stack_trace"] = stack_trace
# Add stack trace to corresponding aten ops
for aten_op in corresponding_aten_ops:
if "args" not in aten_op:
aten_op["args"] = {}
aten_op["args"]["stack_trace"] = stack_trace
event_mapping[node_name] = {
"recorded_event": recorded_event,
"aten_operations": corresponding_aten_ops,
"node": node,
"stack_trace": stack_trace
}
return event_mapping

View File

@ -440,6 +440,7 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@ -790,8 +791,13 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in ("call_function", "call_method", "call_module")
if do_record:
body.append(f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {node.name} ##'); _rf_{node.name}.__enter__()\n")
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@ -1260,6 +1266,9 @@ class Graph:
name = self._graph_namespace.create_name(candidate, None)
n = Node(self, name, op, target, args, kwargs, type_expr)
# print(name)
# breakpoint()
if (
self.owning_module is not None
and getattr(self.owning_module, "_create_node_hooks", None) is not None
@ -1684,6 +1693,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@ -1751,6 +1761,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def _python_code(
@ -1763,6 +1774,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@ -1773,6 +1785,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def __str__(self) -> str:

View File

@ -161,6 +161,7 @@ class Interpreter:
delay=0,
)
print("running inside interpreter")
for node in self.graph.nodes:
pbar.update(1)
if node in self.env: