Support propagating custom meta field to backward graph nodes (#164174)

# Propagate custom metadata to backward

Support propagating user annotation tags to the backward graph by extending the `copy_fwd_metadata_to_bw_nodes` utility (recommended by @xmfan, thanks!).

Example annotation API (added in https://github.com/pytorch/pytorch/pull/163673):

```
class M(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            with fx_traceback.annotate({"fdsp_bucket": 0}):
                x = x + 1
            x = x - 2
            with fx_traceback.annotate({"cuda_stream": 2, "fsdp_bucket": 1}):
                x = x * 2
        x = x / 3
        return x
```
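
For illustration, here is a minimal sketch (not taken verbatim from the new tests) of how the propagated annotation could be inspected on the joint graph produced by `aot_export_joint_with_descriptors`. The `requires_grad=True` input is an assumption made purely so that a backward portion exists for this parameter-free module:

```
# Hedged sketch: inspect the "custom" annotation on backward nodes of the
# joint graph. Mirrors the pattern used in the new tests, but simplified.
from contextlib import ExitStack

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors

model = M()
# Assumption: an input that requires grad, so the joint graph has a backward part.
inputs = (torch.randn(10, requires_grad=True),)

with ExitStack() as stack:
    with fx_traceback.preserve_node_meta():
        joint_with_descriptors = aot_export_joint_with_descriptors(
            stack, model, inputs
        )
    for node in joint_with_descriptors.graph_module.graph.nodes:
        if node.meta.get("partitioner_tag") == "is_backward":
            # After this PR, the "custom" meta of the matching forward node
            # (matched via seq_nr) shows up here as well.
            print(node.name, node.meta.get("custom"))
```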

Assumptions (some inherited from https://github.com/pytorch/pytorch/pull/126573):

- I am trusting the seq_nr mapping introduced to aot_autograd nodes in https://github.com/pytorch/pytorch/pull/103129
- I am also trusting that the forward is single threaded, since seq_nr is thread local. If this isn't always true, we'll need to also plumb thread_id through the same machinery that populates seq_nr.
- **(This is changed in this PR!) I assume all backward graph nodes have "is_backward" as their 'partitioner_tag', and all other nodes are forward graph nodes** (see the sketch after this list). If we don't run export before `aot_export_joint_with_descriptors`, then none of the nodes have "nn_module_stack" in node meta. If we do run export first, then we don't need this change.
- I copy "custom" node meta from forward to backward graph nodes.
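
The classification above boils down to two predicates plus a seq_nr lookup. A condensed sketch of the extended `copy_fwd_metadata_to_bw_nodes` logic (the full change is in the last diff below) looks roughly like this:

```
# Condensed sketch of the new classification + copy logic; the real
# implementation lives in copy_fwd_metadata_to_bw_nodes.

def _is_forward_node_with_seq_nr(node):
    # Forward nodes: anything not tagged "is_backward" by the partitioner.
    return node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta


def _is_backward_node_with_seq_nr(node):
    return node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta


def copy_custom_meta_to_bw_nodes(fx_g):
    # Map each forward node's seq_nr to the node itself ...
    fwd_seq_nr_to_node = {
        n.meta["seq_nr"]: n
        for n in fx_g.graph.nodes
        if _is_forward_node_with_seq_nr(n)
    }
    # ... and copy its "custom" annotation onto the matching backward node.
    for node in fx_g.graph.nodes:
        if not _is_backward_node_with_seq_nr(node):
            continue
        fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
        if fwd_node is not None:
            node.meta["custom"] = fwd_node.meta.get("custom")
```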

Questions:
- Is it a good idea to copy all "custom" node meta? Or should we create a dedicated key in the custom node meta to be copied? @SherlockNoMad
- Do we expect people to run export before using `aot_export_joint_with_descriptors`?
- Can we assume the following for graphs produced by `aot_export_joint_with_descriptors`: "all backward graph nodes have "is_backward" as their 'partitioner_tag', and all other nodes are forward graph nodes"? Maybe this is a question for @ezyang

Tests:

```
python test/functorch/test_aot_joint_with_descriptors.py -k test_preserve_
python test/export/test_export.py -k preserve_anno
python test/distributed/tensor/test_dtensor_export.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164174
Approved by: https://github.com/xmfan, https://github.com/SherlockNoMad
Author: Shangdi Yu
Date: 2025-10-04 05:03:28 +00:00
Committed by: PyTorch MergeBot
Parent: 35c4130fd1
Commit: 6b768e1890
5 changed files with 173 additions and 28 deletions


@ -5,6 +5,7 @@ import unittest
import torch
import torch.distributed as dist
import torch.fx.traceback as fx_traceback
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._functorch.partitioners import min_cut_rematerialization_partition
@ -37,6 +38,18 @@ class SimpleModel(torch.nn.Module):
return self.mlp_1(self.mlp_0(input))
class SimpleModelAnnotated(torch.nn.Module):
def __init__(self, device):
super().__init__()
self.mlp_0 = MLPModule(device)
self.mlp_1 = MLPModule(device)
def forward(self, input):
with fx_traceback.annotate({"pp_stage": 0}):
x = self.mlp_0(input)
return self.mlp_1(x)
def strict_export_and_aot_export_joint_with_descriptors(model, inputs):
# needed for strict export
torch.utils._pytree.register_constant(DTensorSpec)
@ -90,7 +103,7 @@ class DTensorExportTest(TestCase):
)
self.device_type = "cuda"
def _run_test(self, export_fn):
def _run_test(self, export_fn, test_annotation=False):
dp_degree = 2
tp_degree = self.world_size // dp_degree
@ -101,7 +114,11 @@ class DTensorExportTest(TestCase):
mesh_dim_names=["dp", "tp"],
)
model = SimpleModel(self.device_type)
model = None
if test_annotation:
model = SimpleModelAnnotated(self.device_type)
else:
model = SimpleModel(self.device_type)
parallelize_plan = {
"mlp_0.net1": ColwiseParallel(),
"mlp_0.net2": RowwiseParallel(),
@ -131,6 +148,116 @@ class DTensorExportTest(TestCase):
1,
)
if test_annotation:
def has_tag(node):
return "custom" in node.meta and node.meta["custom"] == {"pp_stage": 0}
def marked_nodes(gm):
return [
node.name
for node in gm.graph.nodes
if has_tag(node) and node.op == "call_function"
]
def unmarked_nodes(gm):
return [
node.name
for node in gm.graph.nodes
if not has_tag(node) and node.op == "call_function"
]
marked_nodes_fw = [
"t",
"addmm",
"view",
"relu",
"view_1",
"t_1",
"div",
"addmm_1",
"all_reduce",
"wait_tensor",
"view_2",
"t_12",
]
unmarked_nodes_fw = [
"view_3",
"t_2",
"addmm_2",
"view_4",
"relu_1",
"view_5",
"t_3",
"div_1",
"addmm_3",
"all_reduce_1",
"wait_tensor_1",
"view_6",
"t_4",
"t_8",
]
marked_nodes_bw = [
"mm_4",
"t_13",
"view_1",
"mm_5",
"t_14",
"sum_3",
"view_9",
"t_15",
"detach",
"detach_1",
"detach_6",
"detach_7",
"threshold_backward_1",
"t_16",
"mm_6",
"t_17",
"sum_4",
"view_10",
"t_18",
]
unmarked_nodes_bw = [
"mm",
"t_5",
"view_5",
"mm_1",
"t_6",
"sum_1",
"view_7",
"t_7",
"detach_2",
"detach_3",
"detach_4",
"detach_5",
"threshold_backward",
"mm_2",
"t_9",
"mm_3",
"t_10",
"sum_2",
"view_8",
"t_11",
"all_reduce_2",
"wait_tensor_2",
]
self.assertEqual(marked_nodes(fw_gm), marked_nodes_fw)
self.assertEqual(unmarked_nodes(fw_gm), unmarked_nodes_fw)
self.assertEqual(marked_nodes(bw_gm), marked_nodes_bw)
self.assertEqual(unmarked_nodes(bw_gm), unmarked_nodes_bw)
self.assertEqual(
set(marked_nodes(joint_gm)), set(marked_nodes_fw + marked_nodes_bw)
)
self.assertEqual(
set(unmarked_nodes(joint_gm)),
set(unmarked_nodes_fw + unmarked_nodes_bw),
)
@parametrize(
"export_fn",
[
@ -150,6 +277,9 @@ class DTensorExportTest(TestCase):
def test_strict_export_parallelize_module_with_dtensor_input(self):
self._run_test(strict_export_and_aot_export_joint_with_descriptors)
def test_annotate_aot_export_joint_with_descriptors_alone(self):
self._run_test(aot_export_joint_with_descriptors_alone, True)
instantiate_parametrized_tests(DTensorExportTest)


@ -910,7 +910,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
7|aten.view.default||l__self___fc1
6|aten.t.default||l__self___fc1
5|aten.view.default||l__self___fc1
4|aten.view.default||
4|aten.view.default||flatten
2|aten.detach.default||l__self___relu1
2|aten.detach.default||l__self___relu1
2|aten.threshold_backward.default||l__self___relu1


@ -15573,6 +15573,11 @@ def forward(self, x):
test_serdes=True,
)
@testing.expectedFailureTrainingIRToRunDecomp
@testing.expectedFailureRetraceability
@testing.expectedFailureStrictV2
@testing.expectedFailureStrict # annotation needs to be handled in dynamo
@testing.expectedFailureSerDer
def test_preserve_annotation(self):
class M(torch.nn.Module):
def forward(self, x):
@ -15591,17 +15596,22 @@ def forward(self, x):
ep = export(m, (torch.randn(10),))
for node in ep.graph.nodes:
if node.target == torch.ops.aten.add.default:
if node.op in ("placeholder", "output"):
continue
if node.target == torch.ops.aten.add.Tensor:
self.assertTrue(node.meta["custom"], {"pp_stage": 0, "fdsp_bucket": 0})
if node.target == torch.ops.aten.sub.default:
elif node.target == torch.ops.aten.sub.Tensor:
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.mul.default:
elif node.target == torch.ops.aten.mul.Tensor:
self.assertTrue(
node.meta["custom"],
{"pp_stage": 0, "cuda_stream": 2, "fsdp_bucket": 1},
)
if node.target == torch.ops.aten.div.default:
self.assertTrue(node.meta["custom"], {})
elif node.target == torch.ops.aten.div.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta["custom"], {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
def test_dynamic_shapes_serdes_generic(self):
from torch._export.serde.dynamic_shapes import (


@ -13,6 +13,7 @@ import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
from torch._decomp import decomposition_table
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
from torch._dynamo.testing import normalize_gm
from torch._functorch._aot_autograd.descriptors import (
BufferAOTInput,
@ -777,36 +778,34 @@ class inner_f(torch.nn.Module):
inputs = (torch.randn(4, 3),)
for with_export in [True, False]:
for with_export in [False]: # TODO: make dynamo work for annotation
with ExitStack() as stack:
model = None
with fx_traceback.preserve_node_meta():
if with_export:
ep = torch.export.export(SimpleLinear(), inputs)
model = ep.module()
with torch._dynamo.config.patch(install_free_tensors=True):
# TODO: switch to use the official graph_capture API once it is ready
model = _dynamo_graph_capture_for_export(model)(*inputs)
else:
model = SimpleLinear()
joint_with_descriptors = aot_export_joint_with_descriptors(
stack, model, inputs, decompositions=decomposition_table
stack, model, inputs, decompositions={}
)
for node in joint_with_descriptors.graph_module.graph.nodes:
if (
node.target
in (
torch.ops.prims.transpose.default,
torch.ops.aten.mm.default,
torch.ops.prims.mul.default,
torch.ops.prims.broadcast_in_dim.default,
torch.ops.prims.add.default,
)
# TODO: add annotation to backward graph nodes
and node.meta.get("partitioner_tag") != "is_backward"
if node.op in ("placeholder", "output"):
continue
if node.target != torch.ops.aten.sub.Tensor and node.op not in (
"placeholder",
"output",
):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
if node.target == torch.ops.aten.sub.default:
self.assertTrue(node.meta.get("custom", {}), {})
elif node.target == torch.ops.aten.sub.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta.get("custom", {}), {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
if __name__ == "__main__":


@ -422,14 +422,18 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# the descendants of graph inputs corresponding to fwd inputs, didn't
# seem obvious at first glance on how to partition graph inputs into
# fwd vs bwd without relying on string names.
return "nn_module_stack" in node.meta and "seq_nr" in node.meta
return (
node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
)
def _is_backward_node_with_seq_nr(node):
# For now, assume that if nn_module_stack_metadata is not populated,
# this node is from the backward. Ignore nodes without `seq_nr`.
# TODO(future): there is likely a less brittle way to do this, same
# as with the forward.
return ("nn_module_stack" not in node.meta) and "seq_nr" in node.meta
return (
node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
)
fwd_seq_nr_to_node = {}
for node in fx_g.graph.nodes:
@ -449,8 +453,10 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
# fwd_node should always exist, but handle non-existence just in case
fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
if fwd_node is not None:
node.meta["fwd_nn_module_stack"] = fwd_node.meta["nn_module_stack"]
node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack")
node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack")
# TODO: better to change to a specific field of custom?
node.meta["custom"] = fwd_node.meta.get("custom")
def register_buffer_assignment_hook(mod, assigned_buffers):