Compare commits

...

3 Commits

SHA1        Message                                                         Date
73cdd67a96  codegen trace                                                   2025-10-06 09:14:31 -07:00
c8a4f7b64d  add test                                                        2025-10-03 17:24:09 -07:00
27352f523a  Support propagating custom meta field to backward graph nodes  2025-10-01 09:44:35 -07:00
8 changed files with 260 additions and 28 deletions

View File

@ -5,6 +5,7 @@ import unittest
import torch
import torch.distributed as dist
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
@ -37,6 +38,18 @@ class SimpleModel(torch.nn.Module):
return self.mlp_1(self.mlp_0(input))
class SimpleModelAnnotated(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.mlp_0 = MLPModule(device)
self.mlp_1 = MLPModule(device)
def forward(self, input):
with fx_traceback.annotate({"pp_stage": 0}):
x = self.mlp_0(input)
return self.mlp_1(x)
def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
# needed for strict export
torch.utils._pytree.register_constant(DTensorSpec)
@ -90,7 +103,7 @@ class DTensorExportTest(TestCase):
)
self.device_type = "cuda"
def _run_test(self, export_fn):
def _run_test(self, export_fn, test_annotation=False):
dp_degree = 2
tp_degree = self.world_size // dp_degree
@ -101,7 +114,11 @@ class DTensorExportTest(TestCase):
mesh_dim_names=["dp", "tp"],
)
model = SimpleModel(self.device_type)
model = None
if test_annotation:
model = SimpleModelAnnotated(self.device_type)
else:
model = SimpleModel(self.device_type)
parallelize_plan = {
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
@ -131,6 +148,116 @@ class DTensorExportTest(TestCase):
1,
)
if test_annotation:
def has_tag(node):
return "custom" in node.meta and node.meta["custom"] == {"pp_stage": 0}
def marked_nodes(gm):
return [
node.name
for node in gm.graph.nodes
if has_tag(node) and node.op == "call_function"
]
def unmarked_nodes(gm):
return [
node.name
for node in gm.graph.nodes
if not has_tag(node) and node.op == "call_function"
]
marked_nodes_fw = [
"t",
"addmm",
"view",
"relu",
"view_1",
"t_1",
"div",
"addmm_1",
"all_reduce",
"wait_tensor",
"view_2",
"t_12",
]
unmarked_nodes_fw = [
"view_3",
"t_2",
"addmm_2",
"view_4",
"relu_1",
"view_5",
"t_3",
"div_1",
"addmm_3",
"all_reduce_1",
"wait_tensor_1",
"view_6",
"t_4",
"t_8",
]
marked_nodes_bw = [
"mm_4",
"t_13",
"view_1",
"mm_5",
"t_14",
"sum_3",
"view_9",
"t_15",
"detach",
"detach_1",
"detach_6",
"detach_7",
"threshold_backward_1",
"t_16",
"mm_6",
"t_17",
"sum_4",
"view_10",
"t_18",
]
unmarked_nodes_bw = [
"mm",
"t_5",
"view_5",
"mm_1",
"t_6",
"sum_1",
"view_7",
"t_7",
"detach_2",
"detach_3",
"detach_4",
"detach_5",
"threshold_backward",
"mm_2",
"t_9",
"mm_3",
"t_10",
"sum_2",
"view_8",
"t_11",
"all_reduce_2",
"wait_tensor_2",
]
self.assertEqual(marked_nodes(fw_gm), marked_nodes_fw)
self.assertEqual(unmarked_nodes(fw_gm), unmarked_nodes_fw)
self.assertEqual(marked_nodes(bw_gm), marked_nodes_bw)
self.assertEqual(unmarked_nodes(bw_gm), unmarked_nodes_bw)
self.assertEqual(
set(marked_nodes(joint_gm)), set(marked_nodes_fw + marked_nodes_bw)
)
self.assertEqual(
set(unmarked_nodes(joint_gm)),
set(unmarked_nodes_fw + unmarked_nodes_bw),
)
@parametrize(
"export_fn",
[
@ -150,6 +277,9 @@ class DTensorExportTest(TestCase):
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)
def test_annotate_aot_export_joint_with_descriptors_alone(self):
self._run_test(aot_export_joint_with_descriptors_alone, True)
instantiate_parametrized_tests(DTensorExportTest)
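
As background for the new assertions: fx_traceback.annotate attaches its dict to meta["custom"] on every node traced inside its scope. A minimal sketch outside the DTensor setup, using plain torch.export in place of the export pipelines above (module and shapes are illustrative):

import torch
import torch.fx.traceback as fx_traceback
from torch.export import export

class Tiny(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            x = x + 1  # traced inside the annotated region
        return x * 2  # traced outside it

ep = export(Tiny(), (torch.randn(4),))
for node in ep.graph.nodes:
    if node.op == "call_function":
        # the add should carry {"pp_stage": 0} in meta["custom"]; the mul should not
        print(node.name, node.meta.get("custom"))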

View File

@ -910,7 +910,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
4|aten.view.default||
4|aten.view.default||flatten
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1
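
The expectation change above follows from the metadata-copy rework later in this diff: with backward nodes now identified by partitioner_tag rather than by a missing nn_module_stack, the backward view at SeqNr 4 picks up its forward counterpart's source fn ("flatten") and reports it in the FwdSrcFn column.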

View File

@ -15228,6 +15228,11 @@ def forward(self, x):
test_serdes=True,
)
@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureRetraceability
@testing.expectedFailureStrictV2
@testing.expectedFailureStrict # annotation needs to be handled in dynamo
@testing.expectedFailureSerDer
def test_preserve_annotation(self):
class M(torch.nn.Module):
def forward(self, x):
@ -15246,17 +15251,22 @@ def forward(self, x):
ep = export(m, (torch.randn(10),))
for node in ep.graph.nodes:
if node.target == torch.ops.aten.add.default:
if node.op in ("placeholder", "output"):
continue
if node.target == torch.ops.aten.add.Tensor:
self.assertTrue(node.meta["custom"], {"pp_stage": 0, "fdsp_bucket": 0})
if node.target == torch.ops.aten.sub.default:
elif node.target == torch.ops.aten.sub.Tensor:
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.mul.default:
elif node.target == torch.ops.aten.mul.Tensor:
self.assertTrue(
node.meta["custom"],
{"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1},
)
if node.target == torch.ops.aten.div.default:
self.assertTrue(node.meta["custom"], {})
elif node.target == torch.ops.aten.div.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta["custom"], {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
def test_dynamic_shapes_serdes_generic(self):
from torch._export.serde.dynamic_shapes import (
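
The body of M is elided in this hunk. The following is a hypothetical reconstruction, inferred only from the dicts the assertions expect (keeping the test's "fdsp_bucket" spelling); it illustrates that nested annotate scopes merge their dicts and restore the outer annotation on exit:

import torch
import torch.fx.traceback as fx_traceback

class M(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            with fx_traceback.annotate({"fdsp_bucket": 0}):
                x = x + 1  # custom: {"pp_stage": 0, "fdsp_bucket": 0}
            x = x - 1  # custom: {"pp_stage": 0}
            with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                x = x * 2  # custom: {"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1}
        return x / 3  # outside every annotate scope; the test accepts missing or empty custom here

ep = torch.export.export(M(), (torch.randn(10),))
# The assertions above then walk ep.graph.nodes and check meta["custom"].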

View File

@ -13,6 +13,7 @@ import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
from torch._decomp import decomposition_table
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.testing import normalize_gm
from torch._functorch._aot_autograd.descriptors import (
BufferAOTInput,
@ -777,36 +778,34 @@ class inner_f(torch.nn.Module):
inputs = (torch.randn(4, 3),)
for with_export in [True, False]:
for with_export in [False]: # TODO: make dynamo work for annotation
with ExitStack() as stack:
model = None
with fx_traceback.preserve_node_meta():
if with_export:
ep = torch.export.export(SimpleLinear(), inputs)
model = ep.module()
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
model = _dynamo_graph_capture_for_export(model)(*inputs)
else:
model = SimpleLinear()
joint_with_descriptors = aot_export_joint_with_descriptors(
stack, model, inputs, decompositions=decomposition_table
stack, model, inputs, decompositions={}
)
for node in joint_with_descriptors.graph_module.graph.nodes:
if (
node.target
in (
torch.ops.prims.transpose.default,
torch.ops.aten.mm.default,
torch.ops.prims.mul.default,
torch.ops.prims.broadcast_in_dim.default,
torch.ops.prims.add.default,
)
# TODO: add annotation to backward graph nodes
and node.meta.get("partitioner_tag") != "is_backward"
if node.op in ("placeholder", "output"):
continue
if node.target != torch.ops.aten.sub.Tensor and node.op not in (
"placeholder",
"output",
):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.sub.default:
self.assertTrue(node.meta.get("custom", {}), {})
elif node.target == torch.ops.aten.sub.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta.get("custom", {}), {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
if __name__ == "__main__":
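
A minimal end-to-end sketch of the path this test exercises, with a stand-in module; aot_export_joint_with_descriptors is a private API, so treat the exact signature as subject to change:

from contextlib import ExitStack

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors

class TinyLinear(torch.nn.Module):  # stand-in for the SimpleLinear used above
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            return self.linear(x)

inputs = (torch.randn(4, 3),)
with ExitStack() as stack:
    with fx_traceback.preserve_node_meta():
        joint = aot_export_joint_with_descriptors(stack, TinyLinear(), inputs)
    # With this change, backward nodes inherit meta["custom"] from the
    # forward nodes they originate from (matched by seq_nr).
    for node in joint.graph_module.graph.nodes:
        if node.op == "call_function":
            print(node.name, node.meta.get("partitioner_tag"), node.meta.get("custom"))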

View File

@ -420,14 +420,18 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# the descendants of graph inputs corresponding to fwd inputs, didn't
# seem obvious at first glance on how to partition graph inputs into
# fwd vs bwd without relying on string names.
return "nn_module_stack" in node.meta and "seq_nr" in node.meta
return (
node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
)
def _is_backward_node_with_seq_nr(node):
# Backward nodes are identified by their partitioner tag. Ignore nodes
# without `seq_nr`.
return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta
return (
node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
)
fwd_seq_nr_to_node = {}
for node in fx_g.graph.nodes:
@ -447,8 +451,10 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# fwd_node should always exist, but handle non-existence just in case
fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
if fwd_node is not None:
node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"]
node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack")
node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack")
# TODO: would it be better to copy only a specific field of "custom"?
node.meta["custom"] = fwd_node.meta.get("custom")
def register_buffer_assignment_hook(mod, assigned_buffers):
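
An illustrative condensation of the change above (not the actual implementation): classify nodes by partitioner_tag, pair forward and backward nodes by seq_nr, then copy the forward annotation across:

import torch.fx

def copy_custom_meta_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
    """Illustrative condensation of copy_fwd_metadata_to_bw_nodes above."""
    def is_fwd(n):
        return n.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in n.meta

    def is_bwd(n):
        return n.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in n.meta

    fwd_by_seq_nr = {n.meta["seq_nr"]: n for n in fx_g.graph.nodes if is_fwd(n)}
    for n in fx_g.graph.nodes:
        if is_bwd(n):
            fwd = fwd_by_seq_nr.get(n.meta["seq_nr"])
            if fwd is not None:  # a forward match should exist, but guard anyway
                n.meta["custom"] = fwd.meta.get("custom")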

View File

@ -65,3 +65,76 @@ def get_node_context(node, num_nodes=2) -> str:
break
cur = cur.prev
return "\n".join(node_contexts[::-1])
def map_recorded_events_to_aten_ops_with_stack_trace(graph_module, traced_data):
    """
    Maps recorded profiler events to their corresponding aten operations and
    adds stack traces.

    Args:
        graph_module: The FX GraphModule
        traced_data: JSON of profiler events from a Chrome trace

    Returns:
        Dict mapping recorded event names to their aten operations with added
        stack traces
    """
    trace_events = traced_data.get("traceEvents", [])

    # Create a mapping from node name to node for easy lookup
    node_map = {node.name: node for node in graph_module.graph.nodes}

    # Find aten operation events
    aten_events = [e for e in trace_events if e.get("cat") == "cpu_op"]

    # Map recorded events to aten ops and add stack traces
    event_mapping = {}
    for recorded_event in trace_events:
        if (
            recorded_event.get("cat") == "cpu_op"
            and recorded_event.get("name", "").startswith("## ")
            and recorded_event.get("name", "").endswith(" ##")
        ):
            # Extract node name from "## node_name ##"
            node_name = recorded_event["name"][3:-3]  # Remove "## " and " ##"
            if node_name in node_map:
                node = node_map[node_name]

                # Find corresponding aten operations within this recorded
                # event's time window
                recorded_start = recorded_event["ts"]
                recorded_end = recorded_start + recorded_event["dur"]

                # Find aten ops that fall within this time window
                corresponding_aten_ops = []
                for aten_event in aten_events:
                    aten_start = aten_event["ts"]
                    aten_end = aten_start + aten_event["dur"]
                    # Check if the aten event overlaps the recorded event
                    if (
                        recorded_start <= aten_start <= recorded_end
                        or recorded_start <= aten_end <= recorded_end
                        or (aten_start <= recorded_start and aten_end >= recorded_end)
                    ):
                        corresponding_aten_ops.append(aten_event)

                # Add stack trace to recorded event and aten ops
                stack_trace = node.meta.get("stack_trace", "No stack trace available")

                # Add stack trace to the recorded event
                if "args" not in recorded_event:
                    recorded_event["args"] = {}
                recorded_event["args"]["stack_trace"] = stack_trace

                # Add stack trace to corresponding aten ops
                for aten_op in corresponding_aten_ops:
                    if "args" not in aten_op:
                        aten_op["args"] = {}
                    aten_op["args"]["stack_trace"] = stack_trace

                event_mapping[node_name] = {
                    "recorded_event": recorded_event,
                    "aten_operations": corresponding_aten_ops,
                    "node": node,
                    "stack_trace": stack_trace,
                }

    return event_mapping
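
Hypothetical usage of this helper (its module path is elided in this view). For a non-empty result, the GraphModule's code must have been generated with record_func=True, per the fx.Graph changes below, so that the trace contains one "## <node> ##" cpu_op event per node:

import json

import torch
import torch.fx
from torch.profiler import ProfilerActivity, profile

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = torch.fx.symbolic_trace(Tiny())
# NOTE: a plain symbolic_trace module emits no "## node ##" events, so the
# mapping below would be empty unless gm's code was generated with
# record_func=True.
with profile(activities=[ProfilerActivity.CPU]) as prof:
    gm(torch.randn(4, 3))
prof.export_chrome_trace("trace.json")

with open("trace.json") as f:
    traced_data = json.load(f)

# map_recorded_events_to_aten_ops_with_stack_trace is the helper patched in
# above; its module path is elided in this diff view.
mapping = map_recorded_events_to_aten_ops_with_stack_trace(gm, traced_data)
for node_name, info in mapping.items():
    print(node_name, [e["name"] for e in info["aten_operations"]])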

View File

@ -440,6 +440,7 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@ -777,8 +778,13 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in ("call_function", "call_method", "call_module")
if do_record:
body.append(f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {node.name} ##'); _rf_{node.name}.__enter__()\n")
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@ -1209,6 +1215,9 @@ class Graph:
name = self._graph_namespace.create_name(candidate, None)
n = Node(self, name, op, target, args, kwargs, type_expr)
# print(name)
# breakpoint()
if (
self.owning_module is not None
and getattr(self.owning_module, "_create_node_hooks", None) is not None
@ -1633,6 +1642,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@ -1700,6 +1710,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def _python_code(
@ -1712,6 +1723,7 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@ -1722,6 +1734,7 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def __str__(self) -> str:
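
A sketch of what the new flag emits, assuming a build containing this patch:

import torch
import torch.fx

class Tiny(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x + 1)

gm = torch.fx.symbolic_trace(Tiny())
# Each call_function/call_method/call_module line in the generated source is
# bracketed by _rf_<node>.__enter__() / _rf_<node>.__exit__() record functions.
python_code = gm.graph.python_code(root_module="self", record_func=True)
print(python_code.src)

To actually run the instrumented code, install the source the way GraphModule.recompile() does, compiling python_code.src against python_code.globals.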

View File

@ -161,6 +161,7 @@ class Interpreter:
delay=0,
)
print("running inside interpreter")
for node in self.graph.nodes:
pbar.update(1)
if node in self.env: