[annotate] Annotations should be mapped across submods (#165202)

The forward match for a backward node might live in a different submod, so we should check all submods for potential matches.
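
As an illustration of the cross-submod lookup (a toy sketch with made-up graphs and `seq_nr` values, not the actual aot_autograd pipeline), the fix amounts to building one `seq_nr` → forward-node map over every `GraphModule` submodule and letting backward nodes in any submodule consult it:

```
import torch
import torch.fx as fx

# Two stand-in submodules of a joint graph: the annotated forward node lives in
# one, its backward counterpart in the other (hypothetical setup for this sketch).
fw_gm = fx.symbolic_trace(lambda x: x * 2)
bw_gm = fx.symbolic_trace(lambda g: g * 2)

for n in fw_gm.graph.nodes:
    if n.op == "call_function":
        n.meta["seq_nr"] = 7
        n.meta["custom"] = {"compile_with_inductor": "flex_attention"}

for n in bw_gm.graph.nodes:
    if n.op == "call_function":
        n.meta["seq_nr"] = 7
        n.meta["partitioner_tag"] = "is_backward"

class Joint(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fw_graph0 = fw_gm
        self.bw_graph0 = bw_gm

joint = Joint()

# Pass 1: collect forward nodes by seq_nr from *all* submodules, not just one.
fwd_seq_nr_to_node = {}
for submod in joint.modules():
    if isinstance(submod, fx.GraphModule):
        for n in submod.graph.nodes:
            if "seq_nr" in n.meta and n.meta.get("partitioner_tag") != "is_backward":
                fwd_seq_nr_to_node[n.meta["seq_nr"]] = n

# Pass 2: backward nodes in any submodule inherit the matching custom metadata.
for submod in joint.modules():
    if isinstance(submod, fx.GraphModule):
        for n in submod.graph.nodes:
            if n.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in n.meta:
                src = fwd_seq_nr_to_node.get(n.meta["seq_nr"])
                if src is not None:
                    n.meta["custom"] = src.meta.get("custom")

bw_mul = next(n for n in bw_gm.graph.nodes if n.op == "call_function")
assert bw_mul.meta["custom"] == {"compile_with_inductor": "flex_attention"}
```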

In flex attention, this can happen when `mask_mod` contains operations (such as index) that increase the `seq_nr` of the forward graph nodes; the backward flex_attention nodes then cannot find a match in their own subgraph.
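
For reference, a minimal sketch of the pattern the new test exercises (the sizes, the `sdpa` helper name, and the toy document split are made up for this example; it assumes a CUDA device and a PyTorch build that has `torch.fx.traceback.annotate`):

```
import torch
import torch.fx.traceback as fx_traceback
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

device = "cuda"
batch, heads, seqlen, head_dim = 1, 2, 128, 64

# Per-position document ids; indexing into this tensor inside mask_mod adds
# extra ops (index, eq, ...) to the forward graph and bumps autograd's seq_nr.
seq_idx = torch.zeros(batch, seqlen, dtype=torch.int32, device=device)
seq_idx[:, seqlen // 2 :] = 1

def block_causal_mask(b, h, q_idx, kv_idx):
    return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)

block_mask = create_block_mask(block_causal_mask, None, None, seqlen, seqlen, device=device)

def sdpa(q, k, v):
    # The annotation should also end up on the backward flex_attention nodes,
    # even though they land in a different submod of the joint graph.
    with fx_traceback.annotate({"compile_with_inductor": "flex_attention"}):
        return flex_attention(q, k, v, block_mask=block_mask)

q, k, v = (
    torch.randn(batch, heads, seqlen, head_dim, device=device, requires_grad=True)
    for _ in range(3)
)
sdpa(q, k, v).sum().backward()
```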

```
python test/functorch/test_aot_joint_with_descriptors.py -k preserve_annotate
```

Also tested on the torchtitan `joint_graph_runner` branch; the flex_attention backward nodes are now annotated.

```
NGPU=8 \
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" \
LOG_RANK=0 \
TRAIN_FILE="torchtitan.train" \
TORCHFT_LIGHTHOUSE="http://localhost:29510" \
PYTORCH_ALLOC_CONF="expandable_segments:True" \
torchrun \
    --nproc_per_node=8 \
    --rdzv_backend c10d \
    --rdzv_endpoint="localhost:0" \
    --local-ranks-filter 0 \
    --role rank \
    --tee 3 \
    -m torchtitan.train \
    --job.config_file ./torchtitan/models/llama3/train_configs/debug_model.toml \
    --model.name joint_graph_runner.llama3 \
    --compile.enable \
    --parallelism.data_parallel_shard_degree=2 \
    --parallelism.tensor_parallel_degree=4 \
    --model.flavor=debugmodel_flex_attn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165202
Approved by: https://github.com/SherlockNoMad
Author: Shangdi Yu
Date: 2025-10-14 16:19:38 +00:00
Committed by: PyTorch MergeBot
Commit: 5eddbb5e47 (parent: c9b2a09530)
4 changed files with 218 additions and 87 deletions


```
@@ -18,20 +18,6 @@ def checkpoint_wrapper(fn):
 class AnnotateTests(torch._dynamo.test_case.TestCase):
-    def get_custom_metadata(self, gm):
-        def helper(gm):
-            custom_metadata = []
-            for node in gm.graph.nodes:
-                if hasattr(node, "meta") and node.meta.get("custom", None):
-                    custom_metadata.append((node.op, node.name, node.meta["custom"]))
-                if node.op == "get_attr" and isinstance(
-                    getattr(gm, node.target), torch.fx.GraphModule
-                ):
-                    custom_metadata.append(helper(getattr(gm, node.target)))
-            return custom_metadata
-
-        return "\n".join(str(x) for x in helper(gm))
-
     def test_annotations(self):
         class Mod(torch.nn.Module):
             def forward(self, x):
@@ -53,9 +39,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(len(backend.fw_graphs), 1)
         self.assertEqual(len(backend.bw_graphs), 1)

-        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
-        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
-        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
+        fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
         self.assertExpectedInline(
             str(dynamo_metadata),
             """\
@@ -97,9 +83,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(len(backend.fw_graphs), 1)
         self.assertEqual(len(backend.bw_graphs), 1)

-        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
-        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
-        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
+        fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
         self.assertExpectedInline(
             str(dynamo_metadata),
             """\
@@ -140,9 +126,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(len(backend.fw_graphs), 1)
         self.assertEqual(len(backend.bw_graphs), 1)

-        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
-        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
-        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
+        fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
         self.assertExpectedInline(
             str(dynamo_metadata),
             """[('call_function', 'p', {'stage': 0})]""",  # noqa: B950
@@ -198,9 +184,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(len(backend.fw_graphs), 1)
         self.assertEqual(len(backend.bw_graphs), 1)

-        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
-        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
-        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
+        fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
         self.assertExpectedInline(
             str(dynamo_metadata),
             """\
@@ -243,11 +229,11 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
 ('call_function', 'detach_2', {'compile_inductor': 0})
 ('call_function', 'detach_3', {'compile_inductor': 0})
 ('get_attr', 'fw_graph0', {'compile_inductor': 0})
-[]
+[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
 ('get_attr', 'joint_graph0', {'compile_inductor': 0})
-[]
+[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('placeholder', 'arg5_1', {'compile_inductor': 0}), ('call_function', 'mul_1', {'compile_inductor': 0}), ('call_function', 'mul_2', {'compile_inductor': 0}), ('call_function', 'add', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
 ('get_attr', 'mask_graph0', {'compile_inductor': 0})
-[('call_function', 'ge', {'compile_inductor': 0})]
+[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
 ('call_function', 'flex_attention_backward', {'compile_inductor': 0})
 ('call_function', 'getitem_3', {'compile_inductor': 0})
 ('call_function', 'getitem_4', {'compile_inductor': 0})
```


```
@@ -37,7 +37,30 @@ from torch._functorch.aot_autograd import (
     aot_export_joint_with_descriptors,
 )
 from torch._guards import tracing, TracingContext
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
+
+
+def graph_capture(model, inputs, with_export):
+    gm = model
+    fake_mode = None
+    if with_export:
+        with (
+            torch._dynamo.config.patch(install_free_tensors=True),
+            fx_traceback.preserve_node_meta(),
+        ):
+            # TODO: switch to use the official graph_capture API once it is ready
+            gm = _dynamo_graph_capture_for_export(model)(*inputs)
+            fake_mode = gm.meta.get("fake_mode", None)
+
+    with tracing(TracingContext(fake_mode)):
+        with ExitStack() as stack:
+            joint_with_descriptors = aot_export_joint_with_descriptors(
+                stack,
+                model,
+                inputs,
+            )
+    return joint_with_descriptors.graph_module
+
+
 class TestAOTJointWithDescriptors(TestCase):
@@ -778,40 +801,128 @@ class inner_f(torch.nn.Module):
                 return y - 1

         inputs = (torch.randn(4, 3),)
+        model = SimpleLinear()

-        for with_export in [False]:  # TODO: make dynamo work for annotation
-            with ExitStack() as stack:
-                model = SimpleLinear()
-                fake_mode = None
-
-                stack.enter_context(fx_traceback.preserve_node_meta())
-
-                if with_export:
-                    stack.enter_context(
-                        torch._dynamo.config.patch(install_free_tensors=True)
-                    )
-                    # TODO: switch to use the official graph_capture API once it is ready
-                    model = _dynamo_graph_capture_for_export(model)(*inputs)
-                    fake_mode = model.meta.get("fake_mode", None)
-                stack.enter_context(tracing(TracingContext(fake_mode)))
-                joint_with_descriptors = aot_export_joint_with_descriptors(
-                    stack, model, inputs, decompositions={}
-                )
-
-            for node in joint_with_descriptors.graph_module.graph.nodes:
-                if node.op in ("placeholder", "output"):
-                    continue
-                if node.target != torch.ops.aten.sub.Tensor and node.op not in (
-                    "placeholder",
-                    "output",
-                ):
-                    self.assertTrue(node.meta["custom"], {"pp_stage": 0})
-                elif node.target == torch.ops.aten.sub.Tensor:
-                    if "custom" in node.meta:
-                        self.assertTrue(node.meta.get("custom", {}), {})
-                else:
-                    raise AssertionError(f"Node not checked: {node}, {node.target}")
+        for with_export in [True, False]:
+            graph_module = graph_capture(model, inputs, with_export)
+            custom_metadata = fx_traceback._get_custom_metadata(graph_module)
+            self.assertExpectedInline(
+                str(custom_metadata),
+                """\
+('call_function', 't', {'pp_stage': 0})
+('call_function', 'addmm', {'pp_stage': 0})
+('call_function', 't_1', {'pp_stage': 0})
+('call_function', 'mm', {'pp_stage': 0})
+('call_function', 't_2', {'pp_stage': 0})
+('call_function', 'sum_1', {'pp_stage': 0})
+('call_function', 'view', {'pp_stage': 0})
+('call_function', 't_3', {'pp_stage': 0})""",
+            )
+
+    @requires_cuda
+    def test_preserve_annotate_flex_attention(self):
+        def score_mod(score, b, h, m, n):
+            return score
+
+        def _get_block_causal_mask_mod(seq_idx):
+            def block_causal_mask(b, h, q_idx, kv_idx):
+                # must use this more complicated mask_mod so autograd seq_nr increases
+                return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)
+
+            return block_causal_mask
+
+        a = 12
+        b = 24
+        batch_size = 2
+        seqlen = a * b
+        device = "cuda"
+
+        # Create seq_idx tensor - maps each position to a document/sequence ID
+        # Example: Split sequence into 2 documents for each batch
+        # First half (0:384) belongs to document 0, second half (384:768) to document 1
+        seq_idx = torch.zeros(batch_size, seqlen, dtype=torch.int32, device=device)
+        seq_idx[:, seqlen // 2 :] = 1  # Second half belongs to document 1

+        # Get the mask_mod function with seq_idx captured in closure
+        mask_mod = _get_block_causal_mask_mod(seq_idx)
+
+        # Create block_mask with the mask_mod function (which only takes 4 args)
+        # Note: We don't compile create_block_mask itself, just flex_attention
+        block_mask = create_block_mask(mask_mod, None, None, seqlen, seqlen)
+
+        class FlexAttentionModule(torch.nn.Module):
+            """Flex attention submodule similar to the sdpa in Llama3 Attention"""
+
+            def forward(self, xq, xk, xv):
+                """
+                Args:
+                    xq: Query tensor (bs, n_heads, seqlen, head_dim)
+                    xk: Key tensor (bs, n_heads, seqlen, head_dim)
+                    xv: Value tensor (bs, n_heads, seqlen, head_dim)
+                Returns:
+                    Output tensor (bs, n_heads, seqlen, head_dim)
+                """
+                with fx_traceback.annotate({"compile_with_inductor": "flex_attention"}):
+                    output = flex_attention(
+                        xq, xk, xv, block_mask=block_mask, score_mod=score_mod
+                    )
+                return output
+
+        # Model configuration
+        n_heads = 4
+        head_dim = 64
+
+        # Create input tensors in the shape expected by FlexAttentionModule
+        # Shape: (bs, n_heads, seqlen, head_dim)
+        xq = torch.randn(
+            batch_size, n_heads, seqlen, head_dim, requires_grad=True, device=device
+        )
+        xk = torch.randn(
+            batch_size, n_heads, seqlen, head_dim, requires_grad=True, device=device
+        )
+        xv = torch.randn(
+            batch_size, n_heads, seqlen, head_dim, requires_grad=True, device=device
+        )
+
+        model = FlexAttentionModule().to(device)
+        inputs = (xq, xk, xv)
+        gm = graph_capture(model, inputs, with_export=True)
+        custom_metadata = fx_traceback._get_custom_metadata(gm)
+
+        # not using assertExpectedInline because some CI runs has fewer detach nodes in graph
+        # than other CI runs, so we can't use a fixed string to compare against
+        self.assertTrue(
+            "('get_attr', 'sdpa_score0', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('get_attr', 'sdpa_mask0', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('call_function', 'flex_attention', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('get_attr', 'fw_graph0', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('get_attr', 'joint_graph0', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('get_attr', 'mask_graph0', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('call_function', 'flex_attention_backward', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )

 if __name__ == "__main__":
```


```
@@ -404,28 +404,28 @@ def root_module_when_exporting_non_strict(flat_fn):
     return None


-def _copy_fwd_metadata_to_bw_nodes(fx_g):
-    def _is_forward_node_with_seq_nr(node):
-        # For now, assume that if nn_module_stack_metadata is populated, this
-        # node is from the forward. Ignore nodes without `seq_nr`.
-        # TODO(future): there is likely a less brittle way to do this by walking
-        # the descendants of graph inputs corresponding to fwd inputs, didn't
-        # seem obvious at first glance on how to partition graph inputs into
-        # fwd vs bwd without relying on string names.
-        return (
-            node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
-        )
-
-    def _is_backward_node_with_seq_nr(node):
-        # For now, assume that if nn_module_stack_metadata is not populated,
-        # this node is from the backward. Ignore nodes without `seq_nr`.
-        # TODO(future): there is likely a less brittle way to do this, same
-        # as with the forward.
-        return (
-            node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
-        )
-
-    fwd_seq_nr_to_node = {}
+def _is_forward_node_with_seq_nr(node: torch.fx.Node) -> bool:
+    # For now, assume that if nn_module_stack_metadata is populated, this
+    # node is from the forward. Ignore nodes without `seq_nr`.
+    # TODO(future): there is likely a less brittle way to do this by walking
+    # the descendants of graph inputs corresponding to fwd inputs, didn't
+    # seem obvious at first glance on how to partition graph inputs into
+    # fwd vs bwd without relying on string names.
+    return node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
+
+
+def _is_backward_node_with_seq_nr(node: torch.fx.Node) -> bool:
+    # For now, assume that if nn_module_stack_metadata is not populated,
+    # this node is from the backward. Ignore nodes without `seq_nr`.
+    # TODO(future): there is likely a less brittle way to do this, same
+    # as with the forward.
+    return node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
+
+
+def _collect_fwd_nodes_from_subgraph(
+    fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node]
+) -> None:
+    """Collect forward nodes from a single subgraph into the global mapping."""
     for node in fx_g.graph.nodes:
         if not _is_forward_node_with_seq_nr(node):
             continue
@@ -435,11 +435,17 @@ def _copy_fwd_metadata_to_bw_nodes(fx_g):
             # that the current op did not create an autograd node, and there
             # is no corresponding backward node, so we skip.
             continue
-        fwd_seq_nr_to_node[node.meta["seq_nr"]] = node
+        fwd_seq_nr_to_node[seq_nr] = node
+
+
+def _copy_metadata_to_bw_nodes_in_subgraph(
+    fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node]
+) -> None:
+    """Copy metadata from forward nodes to backward nodes in a single subgraph."""
     for node in fx_g.graph.nodes:
         if not _is_backward_node_with_seq_nr(node):
             continue
         # fwd_node should always exist, but handle non-existence just in case
         fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
         if fwd_node is not None:
@@ -449,7 +455,7 @@ def _copy_fwd_metadata_to_bw_nodes(fx_g):
             node.meta["custom"] = fwd_node.meta.get("custom")


-def copy_fwd_metadata_to_bw_nodes(fx_g):
+def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
     """
     Input: `fx_g` which contains the joint fwd+bwd FX graph created by
     aot_autograd.
@@ -458,15 +464,25 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
     to backward nodes, using the `seq_nr` field as a one-to-many mapping
     from forward node to backward node. This metadata is useful for performance
     profiling and debugging.
+
+    This function supports matching forward and backward nodes across different
+    subgraphs (e.g., in recursive submodules from HOPs), enabling backward nodes
+    in any submodule to match forward nodes in any submodule.
     """
-    # Copy the metadata recursively - useful for HOPs
-    for node in fx_g.graph.nodes:
-        if node.op == "get_attr":
-            submod = getattr(fx_g, node.target)
-            if isinstance(submod, torch.fx.GraphModule):
-                copy_fwd_metadata_to_bw_nodes(submod)
-    _copy_fwd_metadata_to_bw_nodes(fx_g)
+    # Build a global mapping of seq_nr to forward nodes across all subgraphs
+    fwd_seq_nr_to_node: dict[str, torch.fx.Node] = {}
+
+    # First pass: collect all forward nodes from all subgraphs
+    for submod in fx_g.modules():
+        if isinstance(submod, torch.fx.GraphModule):
+            _collect_fwd_nodes_from_subgraph(submod, fwd_seq_nr_to_node)
+
+    # Second pass: copy metadata to backward nodes in all subgraphs
+    # using the global forward mapping
+    for submod in fx_g.modules():
+        if isinstance(submod, torch.fx.GraphModule):
+            _copy_metadata_to_bw_nodes_in_subgraph(submod, fwd_seq_nr_to_node)


 def register_buffer_assignment_hook(mod, assigned_buffers):
```


```
@@ -10,6 +10,7 @@ from torch._utils_internal import signpost_event
 from ._compatibility import compatibility
 from .graph import Graph
+from .graph_module import GraphModule
 from .node import Node
@@ -388,3 +389,20 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
             },
         )
     return {}
+
+
+def _get_custom_metadata(gm: GraphModule) -> str:
+    assert isinstance(gm, GraphModule)
+
+    def helper(gm: GraphModule):
+        custom_metadata = []
+        for node in gm.graph.nodes:
+            if hasattr(node, "meta") and node.meta.get("custom", None):
+                custom_metadata.append((node.op, node.name, node.meta["custom"]))
+            if node.op == "get_attr" and isinstance(
+                getattr(gm, node.target), GraphModule
+            ):
+                custom_metadata.append(helper(getattr(gm, node.target)))
+        return custom_metadata
+
+    return "\n".join(str(x) for x in helper(gm))
```