Compare commits


2 Commits

Author SHA1 Message Date
9728dec054 annotation in dynamo 2025-10-03 15:35:49 -07:00
27352f523a Support propagating custom meta field to backward graph nodes 2025-10-01 09:44:35 -07:00
9 changed files with 145 additions and 77 deletions
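The feature exercised by these commits is the fx_traceback.annotate context manager, which stamps a user-supplied dict onto the "custom" meta field of FX nodes traced while the context is active; the PR propagates that field to backward graph nodes and teaches dynamo to trace through the context. A minimal sketch of the user-facing pattern (the module and shapes below are illustrative, not taken from the diff):

import torch
import torch.fx.traceback as fx_traceback

class AnnotatedBlock(torch.nn.Module):  # illustrative module, not part of the diff
    def __init__(self):
        super().__init__()
        self.lin0 = torch.nn.Linear(10, 10)
        self.lin1 = torch.nn.Linear(10, 10)

    def forward(self, x):
        # nodes traced from inside this context are expected to carry
        # node.meta["custom"] == {"pp_stage": 0} in the captured graph
        with fx_traceback.annotate({"pp_stage": 0}):
            x = self.lin0(x)
        return self.lin1(x)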

View File

@@ -25,6 +25,7 @@ from torch.testing._internal.common_utils import (
)
from torch.testing._internal.distributed._tensor.common_dtensor import MLPModule
from torch.testing._internal.distributed.fake_pg import FakeStore
import torch.fx.traceback as fx_traceback
class SimpleModel(torch.nn.Module):
@@ -36,6 +37,17 @@ class SimpleModel(torch.nn.Module):
def forward(self, input):
return self.mlp_1(self.mlp_0(input))
class SimpleModelAnnotated(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.mlp_0 = MLPModule(device)
self.mlp_1 = MLPModule(device)
def forward(self, input):
with fx_traceback.annotate({"pp_stage": 0}):
x = self.mlp_0(input)
return self.mlp_1(x)
def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
# needed for strict export
@@ -101,35 +113,45 @@ class DTensorExportTest(TestCase):
mesh_dim_names=["dp", "tp"],
)
model = SimpleModel(self.device_type)
parallelize_plan = {
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
"mlp_1.net1": ColwiseParallel(),
"mlp_1.net2": RowwiseParallel(),
}
tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
for annotation in [True, False]:
model = None
if annotation:
model = SimpleModelAnnotated(self.device_type)
else:
model = SimpleModel(self.device_type)
parallelize_plan = {
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
"mlp_1.net1": ColwiseParallel(),
"mlp_1.net2": RowwiseParallel(),
}
tp_model = parallelize_module(model, mesh_2d["tp"], parallelize_plan)
inputs = torch.rand(20, 10, device=self.device_type)
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
inputs = torch.rand(20, 10, device=self.device_type)
inputs = distribute_tensor(inputs, mesh_2d["tp"], placements=[Replicate()])
joint_gm = export_fn(tp_model, inputs)
fw_gm, bw_gm = min_cut_rematerialization_partition(
joint_gm, None, num_fwd_outputs=1
)
joint_gm = export_fn(tp_model, inputs)
fw_gm, bw_gm = min_cut_rematerialization_partition(
joint_gm, None, num_fwd_outputs=1
)
self.assertTrue(
_count_op(joint_gm, torch.ops._c10d_functional.all_reduce.default),
3,
)
self.assertTrue(
_count_op(fw_gm, torch.ops._c10d_functional.all_reduce.default),
2,
)
self.assertTrue(
_count_op(bw_gm, torch.ops._c10d_functional.all_reduce.default),
1,
)
self.assertTrue(
_count_op(joint_gm, torch.ops._c10d_functional.all_reduce.default),
3,
)
self.assertTrue(
_count_op(fw_gm, torch.ops._c10d_functional.all_reduce.default),
2,
)
self.assertTrue(
_count_op(bw_gm, torch.ops._c10d_functional.all_reduce.default),
1,
)
if annotation:
for node in fw_gm.graph.nodes:
if node.op == "call_function":
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
@parametrize(
"export_fn",

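The assertions above rely on a _count_op helper that is defined elsewhere in the test file and does not appear in this diff. A plausible sketch of it (an assumption, not the actual helper), assuming it counts call_function nodes with a given target:

def _count_op(gm, op):
    # count call_function nodes in the graph whose target is the given op
    return sum(
        1 for node in gm.graph.nodes if node.op == "call_function" and node.target == op
    )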
View File

@@ -910,7 +910,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
4|aten.view.default||
4|aten.view.default||flatten
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1

View File

@@ -15228,6 +15228,7 @@ def forward(self, x):
test_serdes=True,
)
@testing.expectedFailureStrict # there is a name mapping bug in placeholder_naming_pass
def test_preserve_annotation(self):
class M(torch.nn.Module):
def forward(self, x):
@@ -15246,16 +15247,16 @@ def forward(self, x):
ep = export(m, (torch.randn(10),))
for node in ep.graph.nodes:
if node.target == torch.ops.aten.add.default:
if node.target in (torch.ops.aten.add.default, torch.ops.aten.add.Tensor):
self.assertTrue(node.meta["custom"], {"pp_stage": 0, "fdsp_bucket": 0})
if node.target == torch.ops.aten.sub.default:
if node.target in (torch.ops.aten.sub.default, torch.ops.aten.sub.Tensor):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.mul.default:
if node.target in (torch.ops.aten.mul.default, torch.ops.aten.mul.Tensor):
self.assertTrue(
node.meta["custom"],
{"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1},
)
if node.target == torch.ops.aten.div.default:
if node.target in (torch.ops.aten.div.default, torch.ops.aten.div.Tensor):
self.assertTrue(node.meta["custom"], {})
def test_dynamic_shapes_serdes_generic(self):

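The forward of M is elided above. A hedged reconstruction that is consistent with the expected "custom" dicts in the assertions (including the literal "fdsp_bucket" key spelled as in the assertion): nested annotate contexts merge their keys via dict update and restore the outer state on exit.

import torch
import torch.fx.traceback as fx_traceback

class M(torch.nn.Module):  # reconstruction, not the actual test module
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            with fx_traceback.annotate({"fdsp_bucket": 0}):
                x = x + 1  # add: {"pp_stage": 0, "fdsp_bucket": 0}
            x = x - 1  # sub: {"pp_stage": 0}
            with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                x = x * 2  # mul: {"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1}
        return x / 3  # div: outside every annotate context, so no custom keys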
View File

@@ -13,6 +13,7 @@ import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
from torch._decomp import decomposition_table
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.testing import normalize_gm
from torch._functorch._aot_autograd.descriptors import (
BufferAOTInput,
@@ -776,37 +777,28 @@ class inner_f(torch.nn.Module):
return y - 1
inputs = (torch.randn(4, 3),)
model = SimpleLinear()
for with_export in [True, False]:
with ExitStack() as stack:
model = None
with fx_traceback.preserve_node_meta():
if with_export:
ep = torch.export.export(SimpleLinear(), inputs)
model = ep.module()
else:
model = SimpleLinear()
if with_export:
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
model = _dynamo_graph_capture_for_export(model)(*inputs)
with fx_traceback.preserve_node_meta():
with ExitStack() as stack:
joint_with_descriptors = aot_export_joint_with_descriptors(
stack, model, inputs, decompositions=decomposition_table
stack, model, inputs, decompositions={}
)
for node in joint_with_descriptors.graph_module.graph.nodes:
if (
node.target
in (
torch.ops.prims.transpose.default,
torch.ops.aten.mm.default,
torch.ops.prims.mul.default,
torch.ops.prims.broadcast_in_dim.default,
torch.ops.prims.add.default,
)
# TODO: add annotation to backward graph nodes
and node.meta.get("partitioner_tag") != "is_backward"
):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.sub.default:
self.assertTrue(node.meta.get("custom", {}), {})
for node in joint_with_descriptors.graph_module.graph.nodes:
if node.target != torch.ops.aten.sub.Tensor and node.op not in (
"placeholder",
"output",
):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.sub.Tensor:
self.assertTrue(node.meta.get("custom", {}), {})
if __name__ == "__main__":

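For reference, a hedged reconstruction of the SimpleLinear toy module used above (only its trailing `return y - 1` is visible in the diff); it is consistent with the assertion that every node except aten.sub.Tensor, placeholders, and outputs carries {"pp_stage": 0}:

import torch
import torch.fx.traceback as fx_traceback

class SimpleLinear(torch.nn.Module):  # reconstruction, not the actual test module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 2)

    def forward(self, x):
        # everything except the final subtraction is annotated, which is why only
        # the sub node is expected to come back with an empty "custom" dict
        with fx_traceback.annotate({"pp_stage": 0}):
            y = self.linear(x)
        return y - 1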
View File

@@ -8,6 +8,7 @@ import sympy
import torch
import torch.fx
import torch.fx.traceback as fx_traceback
import torch.utils._pytree as pytree
from torch._dynamo.convert_frame import CaptureOutput, fullgraph_capture, get_traced_fn
from torch._dynamo.eval_frame import argument_names
@@ -308,7 +309,7 @@ class DynamoGraphTransformer(torch.fx.Transformer):
# Copy important metadata
if hasattr(result, "node") and result.node is not node:
for key in ["val", "example_value", "unbacked_bindings"]:
for key in ["val", "example_value", "unbacked_bindings", "custom"]:
if key in node.meta:
result.node.meta[key] = node.meta[key]
@@ -458,6 +459,7 @@ def _dynamo_graph_capture_for_export(
get_metrics_context(),
dynamo_timed("fullgraph_capture"),
dynamo_config_ctx,
fx_traceback.preserve_node_meta(),
):
out = fullgraph_capture(
module_to_trace,

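Taken together, the two hunks above make the dynamo export path preserve annotations: node metadata produced under preserve_node_meta() is kept during fullgraph_capture, and the "custom" key is copied when the graph is re-emitted by the transformer. A small sketch of reading the result back, assuming the private _dynamo_graph_capture_for_export helper behaves as in the tests above (the Tiny module is illustrative):

import torch
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export

class Tiny(torch.nn.Module):  # illustrative module
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            return x + 1

gm = _dynamo_graph_capture_for_export(Tiny())(torch.randn(4))
for node in gm.graph.nodes:
    if node.op == "call_function":
        # with this change, the add node is expected to carry {"pp_stage": 0}
        print(node.target, node.meta.get("custom"))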
View File

@@ -156,6 +156,8 @@ manual_torch_name_rule_map: dict[
],
] = {
"torch.fx.traceback.annotate": UserFunctionVariable,
"torch.fx.traceback._enter_annotation_context": UserFunctionVariable,
"torch.fx.traceback._exit_annotation_context": UserFunctionVariable,
"torch.onnx.is_in_onnx_export": TorchInGraphFunctionVariable,
"torch.onnx.operators.shape_as_tensor": TorchInGraphFunctionVariable,
"torch.overrides.is_tensor_like": TorchInGraphFunctionVariable,

View File

@@ -505,6 +505,14 @@ class UserFunctionVariable(BaseUserFunctionVariable):
kwargs: "dict[str, VariableTracker]",
) -> "VariableTracker":
# Handle patch_dynamo_config call
if self.fn in (
torch.fx.traceback._enter_annotation_context,
torch.fx.traceback._exit_annotation_context,
):
args_const = [arg.as_python_constant() for arg in args]
self.fn(*args_const)
if self.fn is torch._dynamo.patch_dynamo_config:
try:
args_const = [arg.as_python_constant() for arg in args]

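Because dynamo executes the enter/exit helpers eagerly via as_python_constant(), annotation values must be Python constants at trace time, matching the docstring note added to annotate() further down. A sketch of the pattern this is meant to enable, assuming the rest of the PR is in place:

import torch
import torch.fx.traceback as fx_traceback

@torch.compile(fullgraph=True)
def fn(x):
    # the dict must contain only constants for dynamo to trace the annotation
    with fx_traceback.annotate({"pp_stage": 0}):
        return x + 1

fn(torch.randn(4))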
View File

@@ -420,14 +420,18 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# the descendants of graph inputs corresponding to fwd inputs, didn't
# seem obvious at first glance on how to partition graph inputs into
# fwd vs bwd without relying on string names.
return "nn_module_stack" in node.meta and "seq_nr" in node.meta
return (
node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
)
def _is_backward_node_with_seq_nr(node):
# For now, assume that if nn_module_stack_metadata is not populated,
# this node is from the backward. Ignore nodes without `seq_nr`.
# TODO(future): there is likely a less brittle way to do this, same
# as with the forward.
return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta
return (
node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
)
fwd_seq_nr_to_node = {}
for node in fx_g.graph.nodes:
@@ -447,8 +451,10 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# fwd_node should always exist, but handle non-existence just in case
fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
if fwd_node is not None:
node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"]
node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack")
node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack")
# TODO: better to change to a specific field of custom?
node.meta["custom"] = fwd_node.meta.get("custom")
def register_buffer_assignment_hook(mod, assigned_buffers):

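The net effect of this hunk: forward and backward nodes are matched by seq_nr (with forward vs. backward now distinguished by partitioner_tag instead of the old nn_module_stack heuristic), and the matched backward node inherits the forward node's nn_module_stack, source_fn_stack, and, with this PR, its "custom" annotations. A simplified standalone sketch of that matching:

def copy_custom_meta_to_bw_nodes(fx_g):  # simplified sketch of the logic above
    fwd_by_seq_nr = {}
    for node in fx_g.graph.nodes:
        if node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta:
            fwd_by_seq_nr[node.meta["seq_nr"]] = node
    for node in fx_g.graph.nodes:
        if node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta:
            fwd_node = fwd_by_seq_nr.get(node.meta["seq_nr"])
            if fwd_node is not None:
                node.meta["custom"] = fwd_node.meta.get("custom")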
View File

@@ -242,6 +242,52 @@ def set_stack_trace(stack: list[str]):
current_meta["stack_trace"] = "".join(stack)
@compatibility(is_backward_compatible=False)
def _enter_annotation_context(annotation_dict: dict, context_state: dict):
"""
Apply annotations to the current meta context.
Args:
annotation_dict (dict): The annotations to apply.
context_state (dict): Dictionary to store the context state for restoration.
Returns:
dict: The context state that was updated.
"""
global current_meta
context_state["has_custom"] = "custom" in current_meta
context_state["old_custom"] = {}
# cannot use `old_custom = copy.copy(current_meta.get("custom", {}))` here,
# as dynamo doesn't support copy.copy()
for k, v in current_meta.get("custom", {}).items():
context_state["old_custom"][k] = v # noqa: PERF403
if not context_state["has_custom"]:
current_meta["custom"] = {}
# Update with all key-value pairs from the input dict
current_meta["custom"].update(annotation_dict)
return context_state
@compatibility(is_backward_compatible=False)
def _exit_annotation_context(context_state: dict):
"""
Restore the original meta context state.
Args:
context_state (dict): The context state to restore from.
"""
global current_meta
if context_state["has_custom"]:
# Restore the original custom dict
current_meta["custom"] = context_state["old_custom"]
else:
del current_meta["custom"]
@compatibility(is_backward_compatible=False)
@contextmanager
def annotate(annotation_dict: dict):
@@ -267,6 +313,7 @@ def annotate(annotation_dict: dict):
Args:
annotation_dict (dict): A dictionary of custom key-value pairs to inject
into the FX trace metadata.
If you want dynamo tracing to work, `annotation_dict` may only contain constants.
Example:
>>> with annotate({"source": "custom_pass", "tag": 42}):
@@ -276,26 +323,14 @@ def annotate(annotation_dict: dict):
global current_meta
has_custom = "custom" in current_meta
old_custom = {}
# cannot use `old_custom = copy.copy(current_meta.get("custom", {}))` here,
# as dynamo doesn't support copy.copy()
for k, v in current_meta.get("custom", {}).items():
old_custom[k] = v # noqa: PERF403
# Note: we route through the two helper functions below so that dynamo can intercept them.
# If you want to add any logic, add it to the _enter and _exit functions.
old_custom: dict[str, Any] = {}
try:
if not has_custom:
current_meta["custom"] = {}
# Update with all key-value pairs from the input dict
current_meta["custom"].update(annotation_dict)
_enter_annotation_context(annotation_dict, old_custom)
yield
finally:
if has_custom:
# Restore the original custom dict
current_meta["custom"] = old_custom
else:
del current_meta["custom"]
_exit_annotation_context(old_custom)
@compatibility(is_backward_compatible=False)
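Outside of any tracing, the enter/exit bookkeeping above gives annotate() simple nesting semantics on the module-level current_meta dict: inner contexts merge their keys on entry, restore the previous "custom" dict on exit, and the key is removed entirely once the outermost context exits. A quick sketch:

import torch.fx.traceback as fx_traceback

with fx_traceback.annotate({"pp_stage": 0}):
    with fx_traceback.annotate({"cuda_stream": 2}):
        # both keys are visible while the inner context is active
        print(fx_traceback.current_meta.get("custom"))  # {"pp_stage": 0, "cuda_stream": 2}
    print(fx_traceback.current_meta.get("custom"))  # {"pp_stage": 0}
print("custom" in fx_traceback.current_meta)  # False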