[annotate] add annotate_fn function decorator (#165703)

Example usage:

```
        @fx_traceback.annotate_fn({"pp_stage": 1})
        def example_function(x):
            return x * x

        class SimpleLinear(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(3, 2)

            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    y = self.linear(x)
                y = example_function(y)
                return y - 1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165703
Approved by: https://github.com/SherlockNoMad
This commit is contained in:
Shangdi Yu
2025-10-17 20:10:49 +00:00
committed by PyTorch MergeBot
parent a16fd6b488
commit 75e2a9fae3
2 changed files with 81 additions and 2 deletions

View File

@ -922,6 +922,46 @@ class inner_f(torch.nn.Module):
in custom_metadata
)
def test_preserve_annotate_function(self):
"""Test basic annotate_fn usage"""
@fx_traceback.annotate_fn({"pp_stage": 1})
def example_function(x):
return x * x
class SimpleLinear(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 2)
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
y = self.linear(x)
y = example_function(y)
return y - 1
inputs = (torch.randn(4, 3),)
model = SimpleLinear()
for with_export in [True, False]:
graph_module = graph_capture(model, inputs, with_export)
custom_metadata = fx_traceback._get_custom_metadata(graph_module)
self.assertExpectedInline(
str(custom_metadata),
"""\
('call_function', 't', {'pp_stage': 0})
('call_function', 'addmm', {'pp_stage': 0})
('call_function', 'mul', {'pp_stage': 1})
('call_function', 'mul_1', {'pp_stage': 1})
('call_function', 'mul_2', {'pp_stage': 1})
('call_function', 't_1', {'pp_stage': 0})
('call_function', 'mm', {'pp_stage': 0})
('call_function', 't_2', {'pp_stage': 0})
('call_function', 'sum_1', {'pp_stage': 0})
('call_function', 'view', {'pp_stage': 0})
('call_function', 't_3', {'pp_stage': 0})""",
)
if __name__ == "__main__":
run_tests()

View File

@ -18,6 +18,7 @@ log = logging.getLogger(__name__)
__all__ = [
"annotate",
"annotate_fn",
"preserve_node_meta",
"has_preserved_node_meta",
"set_stack_trace",
@ -266,9 +267,10 @@ def annotate(annotation_dict: dict):
into the FX trace metadata.
Example:
After exiting the context, custom annotations are removed.
>>> with annotate({"source": "custom_pass", "tag": 42}):
... # compute here
# After exiting the context, custom annotations are removed.
... pass # Your computation here
"""
global current_meta
@ -291,6 +293,43 @@ def annotate(annotation_dict: dict):
del current_meta["custom"]
@compatibility(is_backward_compatible=False)
def annotate_fn(annotation_dict: dict):
"""
A decorator that wraps a function with the annotate context manager.
Use this when you want to annotate an entire function instead of a specific code block.
Note:
This API is **not backward compatible** and may evolve in future releases.
Note:
This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
to be used with PT2 family of tracers, e.g. torch.export and dynamo.
Args:
annotation_dict (dict): A dictionary of custom key-value pairs to inject
into the FX trace metadata for all operations in the function.
Example:
All operations in my_function will have {"pp_stage": 1} in their metadata.
>>> @annotate_fn({"pp_stage": 1})
... def my_function(x):
... return x + 1
"""
from functools import wraps
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
with annotate(annotation_dict):
return func(*args, **kwargs)
return wrapper
return decorator
@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
global current_meta