Compare commits


1 Commit

Author SHA1 Message Date
cb930ff08c Tag gradient acc in node 2025-11-11 14:25:48 -08:00
6 changed files with 67 additions and 4 deletions
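Note: the "gradient acc" in the commit message refers to the gradient accumulation node that autograd inserts into the backward graph when a tensor is consumed by more than one operation, so that the incoming gradients are summed. A minimal sketch in plain eager PyTorch (no names from this diff are assumed):

```python
import torch

x = torch.randn(4, 3, requires_grad=True)

# x feeds two branches, so autograd must sum two incoming gradients
# for x during backward -- that implicit sum (an aten.add.Tensor node
# in a captured backward graph) is what this commit tags.
y = (x * 2).sum() + (x * 3).sum()
y.backward()

# d/dx of (2x + 3x), summed: 2 + 3 = 5 everywhere.
assert torch.allclose(x.grad, torch.full_like(x, 5.0))
```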

View File

@@ -950,7 +950,7 @@ SeqNr|OrigAten|SrcFn|FwdSrcFn
 2|aten.threshold_backward.default||relu
 1|aten.native_batch_norm_backward.default||batch_norm
 0|aten.convolution_backward.default||conv2d
-11|aten.add.Tensor||l1_loss
+11|aten.add.Tensor||
 """
             ),
         )

View File

@@ -1092,6 +1092,57 @@ class inner_f(torch.nn.Module):
         )
         self.assertEqual(joint._aot_state.fw_metadata.static_input_indices, [0, 1])
 
+    def test_no_annotation_on_gradient_acc_nodes(self):
+        """Test that gradient accumulation nodes do not pick up custom annotations"""
+
+        class SimpleLinear(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.Linear(3, 2)
+                self.linear2 = nn.Linear(3, 2)
+
+            def forward(self, x):
+                with fx_traceback.annotate({"test": 1}):
+                    return self.linear(x) - self.linear2(x)
+
+        model = SimpleLinear()
+        inputs = (torch.randn(4, 3, requires_grad=True),)
+        graph_module = graph_capture(model, inputs, True)
+        add_nodes = graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.aten.add.Tensor
+        )
+        self.assertEqual(len(add_nodes), 1)
+        gradient_acc_node = add_nodes[0]
+        self.assertTrue(gradient_acc_node.meta["is_gradient_acc"])
+        self.assertEqual(gradient_acc_node.meta.get("custom", {}), {})
+
+        custom_metadata = fx_traceback._get_custom_metadata(graph_module)
+        self.assertExpectedInline(
+            str(custom_metadata),
+            """\
+('call_function', 't', {'test': 1})
+('call_function', 'addmm', {'test': 1})
+('call_function', 't_1', {'test': 1})
+('call_function', 'addmm_1', {'test': 1})
+('call_function', 'sub', {'test': 1})
+('call_function', 'neg', {'test': 1})
+('call_function', 't_2', {'test': 1})
+('call_function', 'mm', {'test': 1})
+('call_function', 't_3', {'test': 1})
+('call_function', 'mm_1', {'test': 1})
+('call_function', 't_4', {'test': 1})
+('call_function', 'sum_1', {'test': 1})
+('call_function', 'view', {'test': 1})
+('call_function', 't_5', {'test': 1})
+('call_function', 't_6', {'test': 1})
+('call_function', 'mm_2', {'test': 1})
+('call_function', 't_7', {'test': 1})
+('call_function', 'mm_3', {'test': 1})
+('call_function', 't_8', {'test': 1})
+('call_function', 'sum_2', {'test': 1})
+('call_function', 'view_1', {'test': 1})
+('call_function', 't_9', {'test': 1})""",
+        )
 
 if __name__ == "__main__":
     run_tests()

View File

@@ -125,9 +125,7 @@ def setup_stacktrace_preservation_hooks(roots: list):
         node.register_prehook(get_prehook(forward_node_stack, node._sequence_nr()))
 
         special_stack = forward_node_stack.copy()
-        special_stack.append(
-            "Gradient addition node due to multiple use of tensor around:"
-        )
+        special_stack.append(fx_traceback.GRADIENT_ACC_SPECIAL_STACK)
         node.register_hook(get_posthook(special_stack, node._sequence_nr()))
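Here the hook setup stops hard-coding the marker string and reuses the shared constant. A sketch of the round trip this enables, assuming only the `GRADIENT_ACC_SPECIAL_STACK` constant introduced further down in this diff (the stack frame text is hypothetical):

```python
import torch.fx.traceback as fx_traceback

# The posthook above bakes the sentinel into the node's recorded stack
# trace; any later consumer can recover the tag with a substring match,
# which is exactly what the TracerBase change below does.
stack_trace = (
    "File model.py, line 12, in forward\n"  # hypothetical frame
    + fx_traceback.GRADIENT_ACC_SPECIAL_STACK
)
assert fx_traceback.GRADIENT_ACC_SPECIAL_STACK in stack_trace
```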

View File

@@ -454,6 +454,11 @@ def _copy_metadata_to_bw_nodes_in_subgraph(
         if not _is_backward_node_with_seq_nr(node):
             continue
 
+        # We exclude gradient accumulation nodes from copying tags
+        if node.meta.get("is_gradient_acc", False):
+            annotation_log.debug("is_gradient_acc")
+            continue
+
         # fwd_node should always exist, but handle non-existence just in case
         fwd_node = fwd_seq_nr_to_node.get(node.meta["seq_nr"])
         if fwd_node is not None:
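With this skip in place, gradient accumulation nodes keep their `seq_nr` but never inherit forward-side tags. Downstream passes can key off the same meta bit; a hedged sketch (the helper name is made up, and `gm` is assumed to be a `torch.fx.GraphModule` produced after this commit):

```python
import torch.fx

def nodes_excluding_grad_acc(gm: torch.fx.GraphModule):
    # Mirrors the check added above: gradient accumulation nodes carry
    # meta["is_gradient_acc"] = True and should be left untagged.
    return [
        n
        for n in gm.graph.nodes
        if not n.meta.get("is_gradient_acc", False)
    ]
```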

View File

@@ -187,6 +187,10 @@ class TracerBase:
                 stack_trace = current_meta.get("stack_trace")
                 if stack_trace:
                     node.stack_trace = stack_trace
+                    if fx_traceback.GRADIENT_ACC_SPECIAL_STACK in stack_trace:
+                        node.meta["is_gradient_acc"] = True
+
                 # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta
                 # If other meta fields are needed, they can be added here
                 for field in _COPY_META_FIELDS:

View File

@@ -38,6 +38,9 @@ current_meta: dict[str, Any] = {}
 current_replay_node: Optional[Node] = None
 should_preserve_node_meta = False
 
+GRADIENT_ACC_SPECIAL_STACK = (
+    "Gradient addition node due to multiple use of tensor around:"
+)
 # =============================================================================
 # FX Metadata Registry for Memory Profiler
 # =============================================================================
@@ -277,6 +280,8 @@ def annotate(annotation_dict: dict):
     tracing system by updating the global `current_meta["custom"]` dictionary.
     The annotations are automatically reverted after the context exits.
 
+    Gradient accumulation nodes will not be annotated.
+
     This is intended for advanced users who need to attach additional metadata to the fx nodes
     (e.g., for debugging, analysis, or external tooling) during export tracing.
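For reference, the documented context manager is used the way the new test above uses it; a minimal sketch (the module and the `{"stage": "proj"}` payload are illustrative, and the annotation only lands in `node.meta["custom"]` under export-style tracing that preserves node meta, as with the test's `graph_capture` helper):

```python
import torch
import torch.nn as nn
import torch.fx.traceback as fx_traceback

class TaggedBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(3, 2)

    def forward(self, x):
        # Under meta-preserving tracing, every fx node from this region
        # gets node.meta["custom"] == {"stage": "proj"} -- except, after
        # this commit, gradient accumulation nodes in the backward graph.
        with fx_traceback.annotate({"stage": "proj"}):
            return self.proj(x)
```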