Compare commits

..

1 Commit

Author SHA1 Message Date
9728dec054 annotation in dynamo 2025-10-03 15:35:49 -07:00
10 changed files with 124 additions and 288 deletions

View File

@@ -5,7 +5,6 @@ import unittest
import torch
import torch.distributed as dist
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
@@ -26,6 +25,7 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore
import torch.fx.traceback as fx_traceback
class SimpleModel(torch.nn.Module):
@@ -37,7 +37,6 @@ class SimpleModel(torch.nn.Module):
def forward(self, input):
return self.mlp_1(self.mlp_0(input))
class SimpleModelAnnotated(torch.nn.Module):
def __init__(self, device):
super().__init__()
@@ -103,7 +102,7 @@ class DTensorExportTest(TestCase):
)
self.device_type = "cuda"
def _run_test(self, export_fn, test_annotation=False):
def _run_test(self, export_fn):
dp_degree = 2
tp_degree = self.world_size // dp_degree
@@ -114,149 +113,45 @@
mesh_dim_names=["dp", "tp"],
)
model = None
if test_annotation:
model = SimpleModelAnnotated(self.device_type)
else:
model = SimpleModel(self.device_type)
parallelize_plan = {
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
"mlp_1.net1": ColwiseParallel(),
"mlp_1.net2": RowwiseParallel(),
}
tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
for annotation in [True, False]:
model = None
if annotation:
model = SimpleModelAnnotated(self.device_type)
else:
model = SimpleModel(self.device_type)
parallelize_plan = {
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
"mlp_1.net1": ColwiseParallel(),
"mlp_1.net2": RowwiseParallel(),
}
tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
inputs = torch.rand(20, 10, device=self.device_type)
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
inputs = torch.rand(20, 10, device=self.device_type)
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
joint_gm = export_fn(tp_model, inputs)
fw_gm, bw_gm = min_cut_rematerialization_partition(
joint_gm, None, num_fwd_outputs=1
)
self.assertTrue(
_count_op(joint_gm, torch.ops._c10d_functional.all_reduce.default),
3,
)
self.assertTrue(
_count_op(fw_gm, torch.ops._c10d_functional.all_reduce.default),
2,
)
self.assertTrue(
_count_op(bw_gm, torch.ops._c10d_functional.all_reduce.default),
1,
)
if test_annotation:
def has_tag(node):
return "custom" in node.meta and node.meta["custom"] == {"pp_stage": 0}
def marked_nodes(gm):
return [
node.name
for node in gm.graph.nodes
if has_tag(node) and node.op == "call_function"
]
def unmarked_nodes(gm):
return [
node.name
for node in gm.graph.nodes
if not has_tag(node) and node.op == "call_function"
]
marked_nodes_fw = [
"t",
"addmm",
"view",
"relu",
"view_1",
"t_1",
"div",
"addmm_1",
"all_reduce",
"wait_tensor",
"view_2",
"t_12",
]
unmarked_nodes_fw = [
"view_3",
"t_2",
"addmm_2",
"view_4",
"relu_1",
"view_5",
"t_3",
"div_1",
"addmm_3",
"all_reduce_1",
"wait_tensor_1",
"view_6",
"t_4",
"t_8",
]
marked_nodes_bw = [
"mm_4",
"t_13",
"view_1",
"mm_5",
"t_14",
"sum_3",
"view_9",
"t_15",
"detach",
"detach_1",
"detach_6",
"detach_7",
"threshold_backward_1",
"t_16",
"mm_6",
"t_17",
"sum_4",
"view_10",
"t_18",
]
unmarked_nodes_bw = [
"mm",
"t_5",
"view_5",
"mm_1",
"t_6",
"sum_1",
"view_7",
"t_7",
"detach_2",
"detach_3",
"detach_4",
"detach_5",
"threshold_backward",
"mm_2",
"t_9",
"mm_3",
"t_10",
"sum_2",
"view_8",
"t_11",
"all_reduce_2",
"wait_tensor_2",
]
self.assertEqual(marked_nodes(fw_gm), marked_nodes_fw)
self.assertEqual(unmarked_nodes(fw_gm), unmarked_nodes_fw)
self.assertEqual(marked_nodes(bw_gm), marked_nodes_bw)
self.assertEqual(unmarked_nodes(bw_gm), unmarked_nodes_bw)
self.assertEqual(
set(marked_nodes(joint_gm)), set(marked_nodes_fw + marked_nodes_bw)
joint_gm = export_fn(tp_model, inputs)
fw_gm, bw_gm = min_cut_rematerialization_partition(
joint_gm, None, num_fwd_outputs=1
)
self.assertEqual(
set(unmarked_nodes(joint_gm)),
set(unmarked_nodes_fw + unmarked_nodes_bw),
self.assertTrue(
_count_op(joint_gm, torch.ops._c10d_functional.all_reduce.default),
3,
)
self.assertTrue(
_count_op(fw_gm, torch.ops._c10d_functional.all_reduce.default),
2,
)
self.assertTrue(
_count_op(bw_gm, torch.ops._c10d_functional.all_reduce.default),
1,
)
if annotation:
for node in fw_gm.graph.nodes:
    if node.op == "call_function" and "custom" in node.meta:
        self.assertEqual(node.meta["custom"], {"pp_stage": 0})
@parametrize(
"export_fn",
@@ -277,9 +172,6 @@ class DTensorExportTest(TestCase):
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)
def test_annotate_aot_export_joint_with_descriptors_alone(self):
self._run_test(aot_export_joint_with_descriptors_alone, True)
instantiate_parametrized_tests(DTensorExportTest)
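For reference, the annotation check in `_run_test` boils down to partitioning each partitioned graph's `call_function` nodes by their `custom` metadata and comparing the resulting name lists. A minimal sketch of that pattern (the `split_nodes_by_annotation` helper is illustrative, not part of this patch):

import torch.fx as fx

def split_nodes_by_annotation(gm: fx.GraphModule, tag: dict):
    """Partition call_function node names by whether node.meta['custom'] equals `tag`."""
    marked, unmarked = [], []
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        (marked if node.meta.get("custom") == tag else unmarked).append(node.name)
    return marked, unmarked

# e.g. marked_fw, unmarked_fw = split_nodes_by_annotation(fw_gm, {"pp_stage": 0})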

View File

@@ -15228,11 +15228,7 @@ def forward(self, x):
test_serdes=True,
)
@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureRetraceability
@testing.expectedFailureStrictV2
@testing.expectedFailureStrict # annotation needs to be handled in dynamo
@testing.expectedFailureSerDer
@testing.expectedFailureStrict # there is a name mapping bug in placeholder_naming_pass
def test_preserve_annotation(self):
class M(torch.nn.Module):
def forward(self, x):
@@ -15251,22 +15247,17 @@ def forward(self, x):
ep = export(m, (torch.randn(10),))
for node in ep.graph.nodes:
if node.op in ("placeholder", "output"):
continue
if node.target == torch.ops.aten.add.Tensor:
if node.target in (torch.ops.aten.add.default, torch.ops.aten.add.Tensor):
self.assertTrue(node.meta["custom"], {"pp_stage": 0, "fdsp_bucket": 0})
elif node.target == torch.ops.aten.sub.Tensor:
if node.target in (torch.ops.aten.sub.default, torch.ops.aten.sub.Tensor):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
elif node.target == torch.ops.aten.mul.Tensor:
if node.target in (torch.ops.aten.mul.default, torch.ops.aten.mul.Tensor):
self.assertTrue(
node.meta["custom"],
{"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1},
)
elif node.target == torch.ops.aten.div.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta["custom"], {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
if node.target in (torch.ops.aten.div.default, torch.ops.aten.div.Tensor):
self.assertTrue(node.meta["custom"], {})
def test_dynamic_shapes_serdes_generic(self):
from torch._export.serde.dynamic_shapes import (
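The `test_preserve_annotation` expectations above hinge on `torch.fx.traceback.annotate` stamping a `custom` dict onto the nodes traced while the context manager is active. A minimal sketch of the pattern (the `Tiny` module and its keys are illustrative, not the test's actual `M`):

import torch
import torch.fx.traceback as fx_traceback
from torch.export import export

class Tiny(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            y = x + 1  # nodes traced here should carry {"pp_stage": 0}
        return y * 2   # nodes traced here should carry no custom metadata

ep = export(Tiny(), (torch.randn(10),))
for node in ep.graph.nodes:
    if node.op == "call_function":
        print(node.target, node.meta.get("custom"))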

View File

@@ -777,35 +777,28 @@ class inner_f(torch.nn.Module):
return y - 1
inputs = (torch.randn(4, 3),)
model = SimpleLinear()
for with_export in [False]: # TODO: make dynamo work for annotation
with ExitStack() as stack:
model = None
with fx_traceback.preserve_node_meta():
if with_export:
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
model = _dynamo_graph_capture_for_export(model)(*inputs)
else:
model = SimpleLinear()
for with_export in [True, False]:
if with_export:
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
model = _dynamo_graph_capture_for_export(model)(*inputs)
with fx_traceback.preserve_node_meta():
with ExitStack() as stack:
joint_with_descriptors = aot_export_joint_with_descriptors(
stack, model, inputs, decompositions={}
)
for node in joint_with_descriptors.graph_module.graph.nodes:
if node.op in ("placeholder", "output"):
continue
if node.target != torch.ops.aten.sub.Tensor and node.op not in (
"placeholder",
"output",
):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
elif node.target == torch.ops.aten.sub.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta.get("custom", {}), {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
for node in joint_with_descriptors.graph_module.graph.nodes:
if node.target != torch.ops.aten.sub.Tensor and node.op not in (
"placeholder",
"output",
):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.sub.Tensor:
self.assertTrue(node.meta.get("custom", {}), {})
if __name__ == "__main__":
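Condensed, the `with_export=True` path of the test above pre-captures the module with dynamo and then runs the joint export under `preserve_node_meta`, so the `custom` annotations survive into the joint graph. A rough sketch using the same private APIs, with `TinyLinear` as a hypothetical stand-in for the test's `SimpleLinear`:

from contextlib import ExitStack

import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            y = self.linear(x)
        return y - 1

inputs = (torch.randn(4, 3),)
with torch._dynamo.config.patch(install_free_tensors=True):
    # as in the patch: to be replaced by the official graph_capture API once ready
    model = _dynamo_graph_capture_for_export(TinyLinear())(*inputs)

with fx_traceback.preserve_node_meta(), ExitStack() as stack:
    joint = aot_export_joint_with_descriptors(stack, model, inputs, decompositions={})

for node in joint.graph_module.graph.nodes:
    print(node.name, node.meta.get("custom"))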

View File

@@ -8,6 +8,7 @@ import sympy
import torch
import torch.fx
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn
from torch._dynamo.eval_frame import argument_names
@@ -308,7 +309,7 @@ class DynamoGraphTransformer(torch.fx.Transformer):
# Copy important metadata
if hasattr(result, "node") and result.node is not node:
for key in ["val", "example_value", "unbacked_bindings"]:
for key in ["val", "example_value", "unbacked_bindings", "custom"]:
if key in node.meta:
result.node.meta[key] = node.meta[key]
@@ -458,6 +459,7 @@ def _dynamo_graph_capture_for_export(
get_metrics_context(),
dynamo_timed("fullgraph_capture"),
dynamo_config_ctx,
fx_traceback.preserve_node_meta(),
):
out = fullgraph_capture(
module_to_trace,
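The `functional_export.py` change is small: `"custom"` joins the metadata keys that `DynamoGraphTransformer` copies onto the new nodes, and the capture now runs under `fx_traceback.preserve_node_meta()`. The copying pattern itself, sketched on a generic `fx.Transformer` subclass (illustrative, not the patch's class):

import torch.fx as fx

COPIED_META_KEYS = ("val", "example_value", "unbacked_bindings", "custom")

class MetaPreservingTransformer(fx.Transformer):
    def run_node(self, node):
        result = super().run_node(node)
        # Transformer returns proxies; copy selected metadata onto the new node.
        if hasattr(result, "node") and result.node is not node:
            for key in COPIED_META_KEYS:
                if key in node.meta:
                    result.node.meta[key] = node.meta[key]
        return result

# new_gm = MetaPreservingTransformer(gm).transform()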

View File

@@ -156,6 +156,8 @@ manual_torch_name_rule_map: dict[
],
] = {
"torch.fx.traceback.annotate": UserFunctionVariable,
"torch.fx.traceback._enter_annotation_context": UserFunctionVariable,
"torch.fx.traceback._exit_annotation_context": UserFunctionVariable,
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
"torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,

View File

@@ -505,6 +505,14 @@ class UserFunctionVariable(BaseUserFunctionVariable):
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# Handle patch_dynamo_config call
if self.fn in (
torch.fx.traceback._enter_annotation_context,
torch.fx.traceback._exit_annotation_context,
):
args_const = [arg.as_python_constant() for arg in args]
self.fn(*args_const)
if self.fn is torch._dynamo.patch_dynamo_config:
try:
args_const = [arg.as_python_constant() for arg in args]
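With the trace rule and the `UserFunctionVariable` branch above, dynamo reduces the helpers' arguments to Python constants and runs them eagerly while tracing, mutating `fx_traceback.current_meta` as a side effect; this is why `annotate` only accepts constant values under dynamo. What effectively executes at trace time, roughly:

import torch.fx.traceback as fx_traceback

annotation = {"pp_stage": 0}  # must reduce to a Python constant
state: dict = {}              # context_state shared by the two helpers

fx_traceback._enter_annotation_context(annotation, state)
assert fx_traceback.current_meta["custom"] == {"pp_stage": 0}
fx_traceback._exit_annotation_context(state)
assert "custom" not in fx_traceback.current_meta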

View File

@@ -65,76 +65,3 @@ def get_node_context(node, num_nodes=2) -> str:
break
cur = cur.prev
return "\n".join(node_contexts[::-1])
def map_recorded_events_to_aten_ops_with_stack_trace(graph_module, traced_data):
"""
Maps recorded profiler events to their corresponding aten operations and adds stack traces.
Args:
graph_module: The FX GraphModule
traced_data: Json of profiler events from Chrome trace
Returns:
Dict mapping recorded event names to their aten operations with added stack traces
"""
trace_events = traced_data.get("traceEvents", [])
# Create a mapping from node name to node for easy lookup
node_map = {node.name: node for node in graph_module.graph.nodes}
# Find aten operation events
aten_events = [e for e in trace_events if e.get("cat") == "cpu_op"]
# Map recorded events to aten ops and add stack traces
event_mapping = {}
for recorded_event in trace_events:
if (recorded_event.get("cat") in ["cpu_op"] and
recorded_event.get("name", "").startswith("## ") and
recorded_event.get("name", "").endswith(" ##")):
# Extract node name from "## node_name ##"
node_name = recorded_event["name"][3:-3] # Remove "## " and " ##"
if node_name in node_map:
node = node_map[node_name]
# Find corresponding aten operations within this recorded event's time window
recorded_start = recorded_event["ts"]
recorded_end = recorded_start + recorded_event["dur"]
# Find aten ops that fall within this time window
corresponding_aten_ops = []
for aten_event in aten_events:
aten_start = aten_event["ts"]
aten_end = aten_start + aten_event["dur"]
# Check if aten event overlaps with recorded event
if (aten_start >= recorded_start and aten_start <= recorded_end) or \
(aten_end >= recorded_start and aten_end <= recorded_end) or \
(aten_start <= recorded_start and aten_end >= recorded_end):
corresponding_aten_ops.append(aten_event)
# Add stack trace to recorded event and aten ops
stack_trace = node.meta.get("stack_trace", "No stack trace available")
# Add stack trace to the recorded event
if "args" not in recorded_event:
recorded_event["args"] = {}
recorded_event["args"]["stack_trace"] = stack_trace
# Add stack trace to corresponding aten ops
for aten_op in corresponding_aten_ops:
if "args" not in aten_op:
aten_op["args"] = {}
aten_op["args"]["stack_trace"] = stack_trace
event_mapping[node_name] = {
"recorded_event": recorded_event,
"aten_operations": corresponding_aten_ops,
"node": node,
"stack_trace": stack_trace
}
return event_mapping

View File

@@ -440,7 +440,6 @@ class CodeGen:
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@@ -778,13 +777,8 @@ class CodeGen:
# node index, which will be deleted later
# after going through _body_transformer
body.append(f"# COUNTER: {i}\n")
do_record = record_func and node.op in ("call_function", "call_method", "call_module")
if do_record:
body.append(f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {node.name} ##'); _rf_{node.name}.__enter__()\n")
emit_node(node)
delete_unused_values(node)
if do_record:
body.append(f"_rf_{node.name}.__exit__(None, None, None)\n")
if len(body) == 0:
# If the Graph has no non-placeholder nodes, no lines for the body
@@ -1215,9 +1209,6 @@ class Graph:
name = self._graph_namespace.create_name(candidate, None)
n = Node(self, name, op, target, args, kwargs, type_expr)
# print(name)
# breakpoint()
if (
self.owning_module is not None
and getattr(self.owning_module, "_create_node_hooks", None) is not None
@@ -1642,7 +1633,6 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@@ -1710,7 +1700,6 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def _python_code(
@@ -1723,7 +1712,6 @@ class Graph:
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
record_func: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@@ -1734,7 +1722,6 @@ class Graph:
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
record_func=record_func,
)
def __str__(self) -> str:

View File

@@ -161,7 +161,6 @@ class Interpreter:
delay=0,
)
print("running inside interpreter")
for node in self.graph.nodes:
pbar.update(1)
if node in self.env:

View File

@@ -242,6 +242,52 @@ def set_stack_trace(stack: list[str]):
current_meta["stack_trace"] = "".join(stack)
@compatibility(is_backward_compatible=False)
def _enter_annotation_context(annotation_dict: dict, context_state: dict):
"""
Apply annotations to the current meta context.
Args:
annotation_dict (dict): The annotations to apply.
context_state (dict): Dictionary to store the context state for restoration.
Returns:
dict: The context state that was updated.
"""
global current_meta
context_state["has_custom"] = "custom" in current_meta
context_state["old_custom"] = {}
# cannot use `old_custom = copy.copy(current_meta.get("custom", {}))` here,
# as dynamo doesn't support copy.copy()
for k, v in current_meta.get("custom", {}).items():
context_state["old_custom"][k] = v # noqa: PERF403
if not context_state["has_custom"]:
current_meta["custom"] = {}
# Update with all key-value pairs from the input dict
current_meta["custom"].update(annotation_dict)
return context_state
@compatibility(is_backward_compatible=False)
def _exit_annotation_context(context_state: dict):
"""
Restore the original meta context state.
Args:
context_state (dict): The context state to restore from.
"""
global current_meta
if context_state["has_custom"]:
# Restore the original custom dict
current_meta["custom"] = context_state["old_custom"]
else:
del current_meta["custom"]
@compatibility(is_backward_compatible=False)
@contextmanager
def annotate(annotation_dict: dict):
@@ -267,6 +313,7 @@ def annotate(annotation_dict: dict):
Args:
annotation_dict (dict): A dictionary of custom key-value pairs to inject
into the FX trace metadata.
If you want dynamo tracing to work, `annotation_dict` can only contain constants.
Example:
>>> with annotate({"source": "custom_pass", "tag": 42}):
@@ -276,26 +323,14 @@ def annotate(annotation_dict: dict):
global current_meta
has_custom = "custom" in current_meta
old_custom = {}
# cannot use `old_custom = copy.copy(current_meta.get("custom", {}))` here,
# as dynamo doesn't support copy.copy()
for k, v in current_meta.get("custom", {}).items():
old_custom[k] = v # noqa: PERF403
# Note: dynamo needs to intercept the _enter/_exit helper functions called below.
# If you want to add any logic, add it to _enter_annotation_context and
# _exit_annotation_context rather than to this context manager.
old_custom: dict[str, Any] = {}
try:
if not has_custom:
current_meta["custom"] = {}
# Update with all key-value pairs from the input dict
current_meta["custom"].update(annotation_dict)
_enter_annotation_context(annotation_dict, old_custom)
yield
finally:
if has_custom:
# Restore the original custom dict
current_meta["custom"] = old_custom
else:
del current_meta["custom"]
_exit_annotation_context(old_custom)
@compatibility(is_backward_compatible=False)
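With the refactor above, `annotate` delegates all bookkeeping to the two helpers so that dynamo can intercept them, while its user-facing behaviour is unchanged: nested annotations merge, and each level is restored on exit. A quick sketch (key names taken from the tests):

import torch.fx.traceback as fx_traceback

with fx_traceback.annotate({"pp_stage": 0}):
    with fx_traceback.annotate({"cuda_stream": 2}):
        # nested annotations merge into a single custom dict...
        assert fx_traceback.current_meta["custom"] == {"pp_stage": 0, "cuda_stream": 2}
    # ...and the inner block's keys are dropped again on exit
    assert fx_traceback.current_meta["custom"] == {"pp_stage": 0}
assert "custom" not in fx_traceback.current_meta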