[annotate] Annotation should be mapped across submod (#165202)

The matching forward node for a backward node might live in a different submod, so we should check all submods for potential matches. In flex attention this can happen when `mask_mod` contains operations (such as index) that increase the seq_nr of the forward graph nodes; the backward flex_attention nodes then cannot find a match in their own subgraph.

```
python test/functorch/test_aot_joint_with_descriptors.py -k preserve_annotate
```

Also tested on the torchtitan joint_graph_runner branch; the flex_attention backward nodes are now annotated.

```
NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" LOG_RANK=0 \
TRAIN_FILE="torchtitan.train" TORCHFT_LIGHTHOUSE="http://localhost:29510" \
PYTORCH_ALLOC_CONF="expandable_segments:True" \
torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
  --local-ranks-filter 0 --role rank --tee 3 -m torchtitan.train \
  --job.config_file ./torchtitan/models/llama3/train_configs/debug_model.toml \
  --model.name joint_graph_runner.llama3 --compile.enable \
  --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 \
  --model.flavor=debugmodel_flex_attn
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165202
Approved by: https://github.com/SherlockNoMad
Committed by: PyTorch MergeBot
Parent: c9b2a09530
Commit: 5eddbb5e47
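In short, regions traced under `fx_traceback.annotate(...)` stamp a `"custom"` entry into each forward node's `meta`, and after the joint graph is built that metadata is copied onto backward nodes by matching the autograd `seq_nr`. The sketch below condenses the approach this PR switches to; it is a hypothetical stand-alone function, not the exact PyTorch internals (those are in `copy_fwd_metadata_to_bw_nodes` and its helpers further down in the diff): collect forward nodes from every `GraphModule` submodule first, then let backward nodes in any submodule look up their match in that global map.

```python
# Condensed sketch of the cross-submodule matching idea; the real logic lives in
# copy_fwd_metadata_to_bw_nodes and its helpers shown in the diff below.
import torch.fx


def copy_custom_meta_across_submods(joint_gm: torch.fx.GraphModule) -> None:
    # Pass 1: index forward nodes by autograd seq_nr across *all* subgraphs,
    # including nested ones produced by HOPs such as flex_attention.
    fwd_by_seq_nr = {}
    for submod in joint_gm.modules():
        if isinstance(submod, torch.fx.GraphModule):
            for node in submod.graph.nodes:
                if "seq_nr" in node.meta and node.meta.get("partitioner_tag") != "is_backward":
                    fwd_by_seq_nr[node.meta["seq_nr"]] = node

    # Pass 2: backward nodes consult the global map, so a backward node in one
    # submodule can inherit "custom" metadata from a forward node in another.
    for submod in joint_gm.modules():
        if isinstance(submod, torch.fx.GraphModule):
            for node in submod.graph.nodes:
                if "seq_nr" in node.meta and node.meta.get("partitioner_tag") == "is_backward":
                    fwd = fwd_by_seq_nr.get(node.meta["seq_nr"])
                    if fwd is not None:
                        node.meta["custom"] = fwd.meta.get("custom")
```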
```diff
@@ -18,20 +18,6 @@ def checkpoint_wrapper(fn):
 
 
 class AnnotateTests(torch._dynamo.test_case.TestCase):
-    def get_custom_metadata(self, gm):
-        def helper(gm):
-            custom_metadata = []
-            for node in gm.graph.nodes:
-                if hasattr(node, "meta") and node.meta.get("custom", None):
-                    custom_metadata.append((node.op, node.name, node.meta["custom"]))
-                if node.op == "get_attr" and isinstance(
-                    getattr(gm, node.target), torch.fx.GraphModule
-                ):
-                    custom_metadata.append(helper(getattr(gm, node.target)))
-            return custom_metadata
-
-        return "\n".join(str(x) for x in helper(gm))
-
     def test_annotations(self):
         class Mod(torch.nn.Module):
             def forward(self, x):
@@ -53,9 +39,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(len(backend.fw_graphs), 1)
         self.assertEqual(len(backend.bw_graphs), 1)
 
-        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
-        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
-        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
+        fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
         self.assertExpectedInline(
             str(dynamo_metadata),
             """\
@@ -97,9 +83,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(len(backend.fw_graphs), 1)
         self.assertEqual(len(backend.bw_graphs), 1)
 
-        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
-        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
-        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
+        fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
         self.assertExpectedInline(
             str(dynamo_metadata),
             """\
@@ -140,9 +126,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(len(backend.fw_graphs), 1)
         self.assertEqual(len(backend.bw_graphs), 1)
 
-        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
-        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
-        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
+        fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
         self.assertExpectedInline(
             str(dynamo_metadata),
             """[('call_function', 'p', {'stage': 0})]""",  # noqa: B950
@@ -198,9 +184,9 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(len(backend.fw_graphs), 1)
         self.assertEqual(len(backend.bw_graphs), 1)
 
-        dynamo_metadata = self.get_custom_metadata(backend.graphs[0])
-        fw_metadata = self.get_custom_metadata(backend.fw_graphs[0])
-        bw_metadata = self.get_custom_metadata(backend.bw_graphs[0])
+        dynamo_metadata = fx_traceback._get_custom_metadata(backend.graphs[0])
+        fw_metadata = fx_traceback._get_custom_metadata(backend.fw_graphs[0])
+        bw_metadata = fx_traceback._get_custom_metadata(backend.bw_graphs[0])
         self.assertExpectedInline(
             str(dynamo_metadata),
             """\
@@ -243,11 +229,11 @@ class AnnotateTests(torch._dynamo.test_case.TestCase):
 ('call_function', 'detach_2', {'compile_inductor': 0})
 ('call_function', 'detach_3', {'compile_inductor': 0})
 ('get_attr', 'fw_graph0', {'compile_inductor': 0})
-[]
+[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('call_function', 'mul', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
 ('get_attr', 'joint_graph0', {'compile_inductor': 0})
-[]
+[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('placeholder', 'arg4_1', {'compile_inductor': 0}), ('placeholder', 'arg5_1', {'compile_inductor': 0}), ('call_function', 'mul_1', {'compile_inductor': 0}), ('call_function', 'mul_2', {'compile_inductor': 0}), ('call_function', 'add', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
 ('get_attr', 'mask_graph0', {'compile_inductor': 0})
-[('call_function', 'ge', {'compile_inductor': 0})]
+[('placeholder', 'arg0_1', {'compile_inductor': 0}), ('placeholder', 'arg1_1', {'compile_inductor': 0}), ('placeholder', 'arg2_1', {'compile_inductor': 0}), ('placeholder', 'arg3_1', {'compile_inductor': 0}), ('call_function', 'ge', {'compile_inductor': 0}), ('output', 'output', {'compile_inductor': 0})]
 ('call_function', 'flex_attention_backward', {'compile_inductor': 0})
 ('call_function', 'getitem_3', {'compile_inductor': 0})
 ('call_function', 'getitem_4', {'compile_inductor': 0})
```
```diff
@@ -37,7 +37,30 @@ from torch._functorch.aot_autograd import (
     aot_export_joint_with_descriptors,
 )
 from torch._guards import tracing, TracingContext
-from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
+from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase
+
+
+def graph_capture(model, inputs, with_export):
+    gm = model
+    fake_mode = None
+    if with_export:
+        with (
+            torch._dynamo.config.patch(install_free_tensors=True),
+            fx_traceback.preserve_node_meta(),
+        ):
+            # TODO: switch to use the official graph_capture API once it is ready
+            gm = _dynamo_graph_capture_for_export(model)(*inputs)
+            fake_mode = gm.meta.get("fake_mode", None)
+
+    with tracing(TracingContext(fake_mode)):
+        with ExitStack() as stack:
+            joint_with_descriptors = aot_export_joint_with_descriptors(
+                stack,
+                model,
+                inputs,
+            )
+            return joint_with_descriptors.graph_module
 
 
 class TestAOTJointWithDescriptors(TestCase):
```
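For orientation, a hypothetical use of the `graph_capture` helper added above; `TinyModel` and the input shape are made up, and the helper plus its imports are assumed to be in scope from the test module.

```python
# Hypothetical usage of the graph_capture test helper defined in the hunk above.
import torch


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return (x * 2).sum()


model = TinyModel()
inputs = (torch.randn(4, 3),)

# with_export=True first captures the module with dynamo (preserving node meta),
# then exports the joint forward+backward graph; with_export=False exports the
# eager module directly.
joint_gm = graph_capture(model, inputs, with_export=True)
print(joint_gm.graph)
```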
```diff
@@ -778,40 +801,128 @@ class inner_f(torch.nn.Module):
                 return y - 1
 
         inputs = (torch.randn(4, 3),)
+        model = SimpleLinear()
 
-        for with_export in [False]:  # TODO: make dynamo work for annotation
-            with ExitStack() as stack:
-                model = SimpleLinear()
-                fake_mode = None
-
-                stack.enter_context(fx_traceback.preserve_node_meta())
-
-                if with_export:
-                    stack.enter_context(
-                        torch._dynamo.config.patch(install_free_tensors=True)
-                    )
-                    # TODO: switch to use the official graph_capture API once it is ready
-                    model = _dynamo_graph_capture_for_export(model)(*inputs)
-                    fake_mode = model.meta.get("fake_mode", None)
-
-                stack.enter_context(tracing(TracingContext(fake_mode)))
-                joint_with_descriptors = aot_export_joint_with_descriptors(
-                    stack, model, inputs, decompositions={}
-                )
-
-                for node in joint_with_descriptors.graph_module.graph.nodes:
-                    if node.op in ("placeholder", "output"):
-                        continue
-                    if node.target != torch.ops.aten.sub.Tensor and node.op not in (
-                        "placeholder",
-                        "output",
-                    ):
-                        self.assertTrue(node.meta["custom"], {"pp_stage": 0})
-                    elif node.target == torch.ops.aten.sub.Tensor:
-                        if "custom" in node.meta:
-                            self.assertTrue(node.meta.get("custom", {}), {})
-                    else:
-                        raise AssertionError(f"Node not checked: {node}, {node.target}")
+        for with_export in [True, False]:
+            graph_module = graph_capture(model, inputs, with_export)
+            custom_metadata = fx_traceback._get_custom_metadata(graph_module)
+            self.assertExpectedInline(
+                str(custom_metadata),
+                """\
+('call_function', 't', {'pp_stage': 0})
+('call_function', 'addmm', {'pp_stage': 0})
+('call_function', 't_1', {'pp_stage': 0})
+('call_function', 'mm', {'pp_stage': 0})
+('call_function', 't_2', {'pp_stage': 0})
+('call_function', 'sum_1', {'pp_stage': 0})
+('call_function', 'view', {'pp_stage': 0})
+('call_function', 't_3', {'pp_stage': 0})""",
+            )
+
+    @requires_cuda
+    def test_preserve_annotate_flex_attention(self):
+        def score_mod(score, b, h, m, n):
+            return score
+
+        def _get_block_causal_mask_mod(seq_idx):
+            def block_causal_mask(b, h, q_idx, kv_idx):
+                # must use this more complicated mask_mod so autograd seq_nr increases
+                return (seq_idx[b, q_idx] == seq_idx[b, kv_idx]) & (q_idx >= kv_idx)
+
+            return block_causal_mask
+
+        a = 12
+        b = 24
+        batch_size = 2
+        seqlen = a * b
+        device = "cuda"
+
+        # Create seq_idx tensor - maps each position to a document/sequence ID
+        # Example: Split sequence into 2 documents for each batch
+        # First half (0:384) belongs to document 0, second half (384:768) to document 1
+        seq_idx = torch.zeros(batch_size, seqlen, dtype=torch.int32, device=device)
+        seq_idx[:, seqlen // 2 :] = 1  # Second half belongs to document 1
+
+        # Get the mask_mod function with seq_idx captured in closure
+        mask_mod = _get_block_causal_mask_mod(seq_idx)
+
+        # Create block_mask with the mask_mod function (which only takes 4 args)
+        # Note: We don't compile create_block_mask itself, just flex_attention
+        block_mask = create_block_mask(mask_mod, None, None, seqlen, seqlen)
+
+        class FlexAttentionModule(torch.nn.Module):
+            """Flex attention submodule similar to the sdpa in Llama3 Attention"""
+
+            def forward(self, xq, xk, xv):
+                """
+                Args:
+                    xq: Query tensor (bs, n_heads, seqlen, head_dim)
+                    xk: Key tensor (bs, n_heads, seqlen, head_dim)
+                    xv: Value tensor (bs, n_heads, seqlen, head_dim)
+                Returns:
+                    Output tensor (bs, n_heads, seqlen, head_dim)
+                """
+                with fx_traceback.annotate({"compile_with_inductor": "flex_attention"}):
+                    output = flex_attention(
+                        xq, xk, xv, block_mask=block_mask, score_mod=score_mod
+                    )
+                return output
+
+        # Model configuration
+        n_heads = 4
+        head_dim = 64
+
+        # Create input tensors in the shape expected by FlexAttentionModule
+        # Shape: (bs, n_heads, seqlen, head_dim)
+        xq = torch.randn(
+            batch_size, n_heads, seqlen, head_dim, requires_grad=True, device=device
+        )
+        xk = torch.randn(
+            batch_size, n_heads, seqlen, head_dim, requires_grad=True, device=device
+        )
+        xv = torch.randn(
+            batch_size, n_heads, seqlen, head_dim, requires_grad=True, device=device
+        )
+
+        model = FlexAttentionModule().to(device)
+        inputs = (xq, xk, xv)
+
+        gm = graph_capture(model, inputs, with_export=True)
+
+        custom_metadata = fx_traceback._get_custom_metadata(gm)
+
+        # not using assertExpectedInline because some CI runs has fewer detach nodes in graph
+        # than other CI runs, so we can't use a fixed string to compare against
+
+        self.assertTrue(
+            "('get_attr', 'sdpa_score0', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('get_attr', 'sdpa_mask0', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('call_function', 'flex_attention', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+
+        self.assertTrue(
+            "('get_attr', 'fw_graph0', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('get_attr', 'joint_graph0', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('get_attr', 'mask_graph0', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
+        self.assertTrue(
+            "('call_function', 'flex_attention_backward', {'compile_with_inductor': 'flex_attention'})"
+            in custom_metadata
+        )
 
 
 if __name__ == "__main__":
```
```diff
@@ -404,28 +404,28 @@ def root_module_when_exporting_non_strict(flat_fn):
         return None
 
 
-def _copy_fwd_metadata_to_bw_nodes(fx_g):
-    def _is_forward_node_with_seq_nr(node):
-        # For now, assume that if nn_module_stack_metadata is populated, this
-        # node is from the forward. Ignore nodes without `seq_nr`.
-        # TODO(future): there is likely a less brittle way to do this by walking
-        # the descendants of graph inputs corresponding to fwd inputs, didn't
-        # seem obvious at first glance on how to partition graph inputs into
-        # fwd vs bwd without relying on string names.
-        return (
-            node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
-        )
-
-    def _is_backward_node_with_seq_nr(node):
-        # For now, assume that if nn_module_stack_metadata is not populated,
-        # this node is from the backward. Ignore nodes without `seq_nr`.
-        # TODO(future): there is likely a less brittle way to do this, same
-        # as with the forward.
-        return (
-            node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
-        )
-
-    fwd_seq_nr_to_node = {}
+def _is_forward_node_with_seq_nr(node: torch.fx.Node) -> bool:
+    # For now, assume that if nn_module_stack_metadata is populated, this
+    # node is from the forward. Ignore nodes without `seq_nr`.
+    # TODO(future): there is likely a less brittle way to do this by walking
+    # the descendants of graph inputs corresponding to fwd inputs, didn't
+    # seem obvious at first glance on how to partition graph inputs into
+    # fwd vs bwd without relying on string names.
+    return node.meta.get("partitioner_tag") != "is_backward" and "seq_nr" in node.meta
+
+
+def _is_backward_node_with_seq_nr(node: torch.fx.Node) -> bool:
+    # For now, assume that if nn_module_stack_metadata is not populated,
+    # this node is from the backward. Ignore nodes without `seq_nr`.
+    # TODO(future): there is likely a less brittle way to do this, same
+    # as with the forward.
+    return node.meta.get("partitioner_tag") == "is_backward" and "seq_nr" in node.meta
+
+
+def _collect_fwd_nodes_from_subgraph(
+    fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node]
+) -> None:
+    """Collect forward nodes from a single subgraph into the global mapping."""
     for node in fx_g.graph.nodes:
         if not _is_forward_node_with_seq_nr(node):
             continue
@@ -435,11 +435,17 @@ def _copy_fwd_metadata_to_bw_nodes(fx_g):
             # that the current op did not create an autograd node, and there
             # is no corresponding backward node, so we skip.
             continue
-        fwd_seq_nr_to_node[node.meta["seq_nr"]] = node
+        fwd_seq_nr_to_node[seq_nr] = node
+
+
+def _copy_metadata_to_bw_nodes_in_subgraph(
+    fx_g: torch.fx.GraphModule, fwd_seq_nr_to_node: dict[str, torch.fx.Node]
+) -> None:
+    """Copy metadata from forward nodes to backward nodes in a single subgraph."""
     for node in fx_g.graph.nodes:
         if not _is_backward_node_with_seq_nr(node):
             continue
+
         # fwd_node should always exist, but handle non-existence just in case
         fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
         if fwd_node is not None:
@@ -449,7 +455,7 @@ def _copy_fwd_metadata_to_bw_nodes(fx_g):
             node.meta["custom"] = fwd_node.meta.get("custom")
 
 
-def copy_fwd_metadata_to_bw_nodes(fx_g):
+def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None:
     """
     Input: `fx_g` which contains the joint fwd+bwd FX graph created by
     aot_autograd.
@@ -458,15 +464,25 @@ def copy_fwd_metadata_to_bw_nodes(fx_g):
     to backward nodes, using the `seq_nr` field as a one-to-many mapping
     from forward node to backward node. This metadata is useful for performance
     profiling and debugging.
+
+    This function supports matching forward and backward nodes across different
+    subgraphs (e.g., in recursive submodules from HOPs), enabling backward nodes
+    in any submodule to match forward nodes in any submodule.
     """
-    # Copy the metadata recursively - useful for HOPs
-    for node in fx_g.graph.nodes:
-        if node.op == "get_attr":
-            submod = getattr(fx_g, node.target)
-            if isinstance(submod, torch.fx.GraphModule):
-                copy_fwd_metadata_to_bw_nodes(submod)
-    _copy_fwd_metadata_to_bw_nodes(fx_g)
+    # Build a global mapping of seq_nr to forward nodes across all subgraphs
+    fwd_seq_nr_to_node: dict[str, torch.fx.Node] = {}
+
+    # First pass: collect all forward nodes from all subgraphs
+    for submod in fx_g.modules():
+        if isinstance(submod, torch.fx.GraphModule):
+            _collect_fwd_nodes_from_subgraph(submod, fwd_seq_nr_to_node)
+
+    # Second pass: copy metadata to backward nodes in all subgraphs
+    # using the global forward mapping
+    for submod in fx_g.modules():
+        if isinstance(submod, torch.fx.GraphModule):
+            _copy_metadata_to_bw_nodes_in_subgraph(submod, fwd_seq_nr_to_node)
 
 
 def register_buffer_assignment_hook(mod, assigned_buffers):
```
```diff
@@ -10,6 +10,7 @@ from torch._utils_internal import signpost_event
 
 from ._compatibility import compatibility
 from .graph import Graph
+from .graph_module import GraphModule
 from .node import Node
 
 
@@ -388,3 +389,20 @@ def get_graph_provenance_json(graph: Graph) -> dict[str, Any]:
             },
         )
     return {}
+
+
+def _get_custom_metadata(gm: GraphModule) -> str:
+    assert isinstance(gm, GraphModule)
+
+    def helper(gm: GraphModule):
+        custom_metadata = []
+        for node in gm.graph.nodes:
+            if hasattr(node, "meta") and node.meta.get("custom", None):
+                custom_metadata.append((node.op, node.name, node.meta["custom"]))
+            if node.op == "get_attr" and isinstance(
+                getattr(gm, node.target), GraphModule
+            ):
+                custom_metadata.append(helper(getattr(gm, node.target)))
+        return custom_metadata
+
+    return "\n".join(str(x) for x in helper(gm))
```
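Finally, a small illustration of what `_get_custom_metadata` reports. Here the `"custom"` entries are stamped by hand on a symbolically traced module purely for demonstration; in the tests above they come from tracing under `fx_traceback.annotate(...)`, and the exact node names in the output may differ.

```python
import torch
import torch.fx as fx
import torch.fx.traceback as fx_traceback


class M(torch.nn.Module):
    def forward(self, x):
        return x * 2 + 1


gm = fx.symbolic_trace(M())
for node in gm.graph.nodes:
    if node.op == "call_function":
        # Manually stamp the same kind of metadata that fx_traceback.annotate
        # would attach during tracing.
        node.meta["custom"] = {"pp_stage": 0}

print(fx_traceback._get_custom_metadata(gm))
# Expected output along the lines of:
# ('call_function', 'mul', {'pp_stage': 0})
# ('call_function', 'add', {'pp_stage': 0})
```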