[annotate] Annotation should be mapped across submods (#165202)

The matching forward node for a backward node might live in a different submod, so we should check all submods for potential matches.
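Conceptually, the fix replaces the per-submodule recursion with a two-pass walk over the whole joint graph. A simplified, hedged sketch of the updated `copy_fwd_metadata_to_bw_nodes` pass shown in the diff below (`joint_gm` stands in for the joint fwd+bwd graph module):

```
# Pass 1: collect forward nodes keyed by seq_nr from every submodule.
fwd_by_seq_nr = {}
for submod in joint_gm.modules():
    if isinstance(submod, torch.fx.GraphModule):
        for node in submod.graph.nodes:
            if node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta:
                fwd_by_seq_nr[node.meta["seq_nr"]] = node

# Pass 2: copy "custom" metadata onto backward nodes in any submodule, so a
# backward node no longer needs its forward match to sit in its own subgraph.
for submod in joint_gm.modules():
    if isinstance(submod, torch.fx.GraphModule):
        for node in submod.graph.nodes:
            if node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta:
                fwd_node = fwd_by_seq_nr.get(node.meta["seq_nr"])
                if fwd_node is not None:
                    node.meta["custom"] = fwd_node.meta.get("custom")
```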

In flex attention, this can happen when `mask_mod` contains operations (such as indexing) that increase the seq_nr of the forward graph nodes; the backward flex_attention nodes then cannot find a match in their own subgraph.
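As a minimal, hedged illustration of the annotation flow being preserved (mirroring the helpers used by the new test below; not a verbatim excerpt):

```
from contextlib import ExitStack

import torch
import torch.fx.traceback as fx_traceback
from torch._functorch.aot_autograd import aot_export_joint_with_descriptors
from torch._guards import tracing, TracingContext


class Mod(torch.nn.Module):
    def forward(self, x):
        with fx_traceback.annotate({"pp_stage": 0}):
            y = torch.matmul(x, x)
        return y.sum()


with fx_traceback.preserve_node_meta(), tracing(TracingContext(None)), ExitStack() as stack:
    joint = aot_export_joint_with_descriptors(
        stack, Mod(), (torch.randn(4, 4, requires_grad=True),)
    )
# After this PR, matching backward nodes (in any submodule) carry the same
# {"pp_stage": 0} annotation as their forward counterparts.
print(fx_traceback._get_custom_metadata(joint.graph_module))
```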

```
python test/functorch/test_aot_joint_with_descriptors.py -k preserve_annotate
```

Also tested on the torchtitan joint_graph_runner branch; the flex_attention backward nodes are now annotated.

```
NGPU=8 \
CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" \
LOG_RANK=0 \
TRAIN_FILE="torchtitan.train" \
TORCHFT_LIGHTHOUSE="http://localhost:29510" \
PYTORCH_ALLOC_CONF="expandable_segments:True" \
torchrun \
    --nproc_per_node=8 \
    --rdzv_backend c10d \
    --rdzv_endpoint="localhost:0" \
    --local-ranks-filter 0 \
    --role rank \
    --tee 3 \
    -m torchtitan.train \
    --job.config_file ./torchtitan/models/llama3/train_configs/debug_model.toml \
    --model.name joint_graph_runner.llama3 \
    --compile.enable \
    --parallelism.data_parallel_shard_degree=2 \
    --parallelism.tensor_parallel_degree=4 \
    --model.flavor=debugmodel_flex_attn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165202
Approved by: https://github.com/SherlockNoMad
Shangdi Yu authored on 2025-10-14 16:19:38 +00:00; committed by PyTorch MergeBot
commit 5eddbb5e47 (parent c9b2a09530)
4 changed files with 218 additions and 87 deletions


@@ -18,20 +18,6 @@ def checkpoint_wrapper(fn):
class AnnotateTests(torch._dynamo.test_case.TestCase):
def get_custom_metadata(self, gm):
def helper(gm):
custom_metadata = []
for node in gm.graph.nodes:
if hasattr(node, "meta") and node.meta.get("custom", None):
custom_metadata.append((node.op, node.name, node.meta["custom"]))
if node.op == "get_attr" and isinstance(
getattr(gm, node.target), torch.fx.GraphModule
):
custom_metadata.append(helper(getattr(gm, node.target)))
return custom_metadata
return "\n".join(str(x) for x in helper(gm))
def test_annotations(self):
class Mod(torch.nn.Module):
def forward(self, x):
@@ -53,9 +39,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
self.assertEqual(len(backend.fw_graphs), 1)
self.assertEqual(len(backend.bw_graphs), 1)
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
self.assertExpectedInline(
str(dynamo_metadata),
"""\
@@ -97,9 +83,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
self.assertEqual(len(backend.fw_graphs), 1)
self.assertEqual(len(backend.bw_graphs), 1)
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
self.assertExpectedInline(
str(dynamo_metadata),
"""\
@@ -140,9 +126,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
self.assertEqual(len(backend.fw_graphs), 1)
self.assertEqual(len(backend.bw_graphs), 1)
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
self.assertExpectedInline(
str(dynamo_metadata),
"""[('call_function', 'p', {'stage': 0})]""", # noqa: B950
@@ -198,9 +184,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
self.assertEqual(len(backend.fw_graphs), 1)
self.assertEqual(len(backend.bw_graphs), 1)
dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
self.assertExpectedInline(
str(dynamo_metadata),
"""\
@@ -243,11 +229,11 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
('call_function', 'detach_2', {'compile_inductor': 0})
('call_function', 'detach_3', {'compile_inductor': 0})
('get_attr', 'fw_graph0', {'compile_inductor': 0})
[]
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('get_attr', 'joint_graph0', {'compile_inductor': 0})
[]
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('placeholder', 'arg5_1', {'compile_inductor': 0}), ('call_function', 'mul_1', {'compile_inductor': 0}), ('call_function', 'mul_2', {'compile_inductor': 0}), ('call_function', 'add', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('get_attr', 'mask_graph0', {'compile_inductor': 0})
[('call_function', 'ge', {'compile_inductor': 0})]
[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
('call_function', 'flex_attention_backward', {'compile_inductor': 0})
('call_function', 'getitem_3', {'compile_inductor': 0})
('call_function', 'getitem_4', {'compile_inductor': 0})


@@ -37,7 +37,30 @@ from torch._functorch.aot_autograd import (
aot_export_joint_with_descriptors,
)
from torch._guards import tracing, TracingContext
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
def graph_capture(model, inputs, with_export):
gm = model
fake_mode = None
if with_export:
with (
torch._dynamo.config.patch(install_free_tensors=True),
fx_traceback.preserve_node_meta(),
):
# TODO: switch to use the official graph_capture API once it is ready
gm = _dynamo_graph_capture_for_export(model)(*inputs)
fake_mode = gm.meta.get("fake_mode", None)
with tracing(TracingContext(fake_mode)):
with ExitStack() as stack:
joint_with_descriptors = aot_export_joint_with_descriptors(
stack,
model,
inputs,
)
return joint_with_descriptors.graph_module
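For reference, a hedged usage sketch of the `graph_capture` helper added above (`SimpleLinear` is the small annotated test model defined elsewhere in this file):

```
# Hypothetical call pattern for the graph_capture helper introduced above.
model = SimpleLinear()
inputs = (torch.randn(4, 3),)
gm = graph_capture(model, inputs, with_export=True)
print(fx_traceback._get_custom_metadata(gm))  # one ('op', 'name', {...}) entry per annotated node
```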
class TestAOTJointWithDescriptors(TestCase):
@@ -778,40 +801,128 @@ class inner_f(torch.nn.Module):
return y - 1
inputs = (torch.randn(4, 3),)
for with_export in [False]: # TODO: make dynamo work for annotation
with ExitStack() as stack:
model = SimpleLinear()
fake_mode = None
stack.enter_context(fx_traceback.preserve_node_meta())
if with_export:
stack.enter_context(
torch._dynamo.config.patch(install_free_tensors=True)
)
# TODO: switch to use the official graph_capture API once it is ready
model = _dynamo_graph_capture_for_export(model)(*inputs)
fake_mode = model.meta.get("fake_mode", None)
stack.enter_context(tracing(TracingContext(fake_mode)))
joint_with_descriptors = aot_export_joint_with_descriptors(
stack, model, inputs, decompositions={}
for with_export in [True, False]:
graph_module = graph_capture(model, inputs, with_export)
custom_metadata = fx_traceback._get_custom_metadata(graph_module)
self.assertExpectedInline(
str(custom_metadata),
"""\
('call_function', 't', {'pp_stage': 0})
('call_function', 'addmm', {'pp_stage': 0})
('call_function', 't_1', {'pp_stage': 0})
('call_function', 'mm', {'pp_stage': 0})
('call_function', 't_2', {'pp_stage': 0})
('call_function', 'sum_1', {'pp_stage': 0})
('call_function', 'view', {'pp_stage': 0})
('call_function', 't_3', {'pp_stage': 0})""",
)
for node in joint_with_descriptors.graph_module.graph.nodes:
if node.op in ("placeholder", "output"):
continue
if node.target != torch.ops.aten.sub.Tensor and node.op not in (
"placeholder",
"output",
):
self.assertTrue(node.meta["custom"], {"pp_stage": 0})
elif node.target == torch.ops.aten.sub.Tensor:
if "custom" in node.meta:
self.assertTrue(node.meta.get("custom", {}), {})
else:
raise AssertionError(f"Node not checked: {node}, {node.target}")
@requires_cuda
def test_preserve_annotate_flex_attention(self):
def score_mod(score, b, h, m, n):
return score
def _get_block_causal_mask_mod(seq_idx):
def block_causal_mask(b, h, q_idx, kv_idx):
# must use this more complicated mask_mod so autograd seq_nr increases
return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)
return block_causal_mask
a = 12
b = 24
batch_size = 2
seqlen = a * b
device = "cuda"
# Create seq_idx tensor - maps each position to a document/sequence ID
# Example: Split sequence into 2 documents for each batch
# First half (0:144) belongs to document 0, second half (144:288) to document 1
seq_idx = torch.zeros(batch_size, seqlen, dtype=torch.int32, device=device)
seq_idx[:, seqlen // 2 :] = 1 # Second half belongs to document 1
# Get the mask_mod function with seq_idx captured in closure
mask_mod = _get_block_causal_mask_mod(seq_idx)
# Create block_mask with the mask_mod function (which only takes 4 args)
# Note: We don't compile create_block_mask itself, just flex_attention
block_mask = create_block_mask(mask_mod, None, None, seqlen, seqlen)
class FlexAttentionModule(torch.nn.Module):
"""Flex attention submodule similar to the sdpa in Llama3 Attention"""
def forward(self, xq, xk, xv):
"""
Args:
xq: Query tensor (bs, n_heads, seqlen, head_dim)
xk: Key tensor (bs, n_heads, seqlen, head_dim)
xv: Value tensor (bs, n_heads, seqlen, head_dim)
Returns:
Output tensor (bs, n_heads, seqlen, head_dim)
"""
with fx_traceback.annotate({"compile_with_inductor": "flex_attention"}):
output = flex_attention(
xq, xk, xv, block_mask=block_mask, score_mod=score_mod
)
return output
# Model configuration
n_heads = 4
head_dim = 64
# Create input tensors in the shape expected by FlexAttentionModule
# Shape: (bs, n_heads, seqlen, head_dim)
xq = torch.randn(
batch_size, n_heads, seqlen, head_dim, requires_grad=True, device=device
)
xk = torch.randn(
batch_size, n_heads, seqlen, head_dim, requires_grad=True, device=device
)
xv = torch.randn(
batch_size, n_heads, seqlen, head_dim, requires_grad=True, device=device
)
model = FlexAttentionModule().to(device)
inputs = (xq, xk, xv)
gm = graph_capture(model, inputs, with_export=True)
custom_metadata = fx_traceback._get_custom_metadata(gm)
# not using assertExpectedInline because some CI runs have fewer detach nodes in the graph
# than other CI runs, so we can't use a fixed string to compare against
self.assertTrue(
"('get_attr', 'sdpa_score0', {'compile_with_inductor': 'flex_attention'})"
in custom_metadata
)
self.assertTrue(
"('get_attr', 'sdpa_mask0', {'compile_with_inductor': 'flex_attention'})"
in custom_metadata
)
self.assertTrue(
"('call_function', 'flex_attention', {'compile_with_inductor': 'flex_attention'})"
in custom_metadata
)
self.assertTrue(
"('get_attr', 'fw_graph0', {'compile_with_inductor': 'flex_attention'})"
in custom_metadata
)
self.assertTrue(
"('get_attr', 'joint_graph0', {'compile_with_inductor': 'flex_attention'})"
in custom_metadata
)
self.assertTrue(
"('get_attr', 'mask_graph0', {'compile_with_inductor': 'flex_attention'})"
in custom_metadata
)
self.assertTrue(
"('call_function', 'flex_attention_backward', {'compile_with_inductor': 'flex_attention'})"
in custom_metadata
)
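To see which backward nodes picked up the annotation after this change (e.g. `flex_attention_backward` inside its own submodule), a hedged inspection snippet over the `gm` captured in the flex attention test above:

```
# Walk every GraphModule (the root and HOP submodules alike) and print the
# nodes that carry a "custom" annotation.
for submod in gm.modules():
    if isinstance(submod, torch.fx.GraphModule):
        for node in submod.graph.nodes:
            if node.meta.get("custom"):
                print(node.op, node.name, node.meta["custom"])
```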
if __name__ == "__main__":


@@ -404,28 +404,28 @@ def root_module_when_exporting_non_strict(flat_fn):
return None
def _copy_fwd_metadata_to_bw_nodes(fx_g):
def _is_forward_node_with_seq_nr(node):
def _is_forward_node_with_seq_nr(node: torch.fx.Node) -> bool:
# For now, assume that if nn_module_stack_metadata is populated, this
# node is from the forward. Ignore nodes without `seq_nr`.
# TODO(future): there is likely a less brittle way to do this by walking
# the descendants of graph inputs corresponding to fwd inputs, didn't
# seem obvious at first glance on how to partition graph inputs into
# fwd vs bwd without relying on string names.
return (
node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
)
return node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
def _is_backward_node_with_seq_nr(node):
def _is_backward_node_with_seq_nr(node: torch.fx.Node) -> bool:
# For now, assume that if nn_module_stack_metadata is not populated,
# this node is from the backward. Ignore nodes without `seq_nr`.
# TODO(future): there is likely a less brittle way to do this, same
# as with the forward.
return (
node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
)
return node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
fwd_seq_nr_to_node = {}
def _collect_fwd_nodes_from_subgraph(
fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node]
) -> None:
"""Collect forward nodes from a single subgraph into the global mapping."""
for node in fx_g.graph.nodes:
if not _is_forward_node_with_seq_nr(node):
continue
@@ -435,11 +435,17 @@ def _copy_fwd_metadata_to_bw_nodes(fx_g):
# that the current op did not create an autograd node, and there
# is no corresponding backward node, so we skip.
continue
fwd_seq_nr_to_node[node.meta["seq_nr"]] = node
fwd_seq_nr_to_node[seq_nr] = node
def _copy_metadata_to_bw_nodes_in_subgraph(
fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node]
) -> None:
"""Copy metadata from forward nodes to backward nodes in a single subgraph."""
for node in fx_g.graph.nodes:
if not _is_backward_node_with_seq_nr(node):
continue
# fwd_node should always exist, but handle non-existence just in case
fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
if fwd_node is not None:
@@ -449,7 +455,7 @@ def _copy_fwd_metadata_to_bw_nodes(fx_g):
node.meta["custom"] = fwd_node.meta.get("custom")
def copy_fwd_metadata_to_bw_nodes(fx_g):
def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
"""
Input: `fx_g` which contains the joint fwd+bwd FX graph created by
aot_autograd.
@@ -458,15 +464,25 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
to backward nodes, using the `seq_nr` field as a one-to-many mapping
from forward node to backward node. This metadata is useful for performance
profiling and debugging.
This function supports matching forward and backward nodes across different
subgraphs (e.g., in recursive submodules from HOPs), enabling backward nodes
in any submodule to match forward nodes in any submodule.
"""
# Copy the metadata recursively - useful for HOPs
for node in fx_g.graph.nodes:
if node.op == "get_attr":
submod = getattr(fx_g, node.target)
# Build a global mapping of seq_nr to forward nodes across all subgraphs
fwd_seq_nr_to_node: dict[str, torch.fx.Node] = {}
# First pass: collect all forward nodes from all subgraphs
for submod in fx_g.modules():
if isinstance(submod, torch.fx.GraphModule):
copy_fwd_metadata_to_bw_nodes(submod)
_copy_fwd_metadata_to_bw_nodes(fx_g)
_collect_fwd_nodes_from_subgraph(submod, fwd_seq_nr_to_node)
# Second pass: copy metadata to backward nodes in all subgraphs
# using the global forward mapping
for submod in fx_g.modules():
if isinstance(submod, torch.fx.GraphModule):
_copy_metadata_to_bw_nodes_in_subgraph(submod, fwd_seq_nr_to_node)
def register_buffer_assignment_hook(mod, assigned_buffers):


@@ -10,6 +10,7 @@ from torch._utils_internal import signpost_event
from ._compatibility import compatibility
from .graph import Graph
from .graph_module import GraphModule
from .node import Node
@@ -388,3 +389,20 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
},
)
return {}
def _get_custom_metadata(gm: GraphModule) -> str:
assert isinstance(gm, GraphModule)
def helper(gm: GraphModule):
custom_metadata = []
for node in gm.graph.nodes:
if hasattr(node, "meta") and node.meta.get("custom", None):
custom_metadata.append((node.op, node.name, node.meta["custom"]))
if node.op == "get_attr" and isinstance(
getattr(gm, node.target), GraphModule
):
custom_metadata.append(helper(getattr(gm, node.target)))
return custom_metadata
return "\n".join(str(x) for x in helper(gm))