mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[annotate] add annotate_fn function decorator (#165703)
Example usage: ``` @fx_traceback.annotate_fn({"pp_stage": 1}) def example_function(x): return x * x class SimpleLinear(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(3, 2) def forward(self, x): with fx_traceback.annotate({"pp_stage": 0}): y = self.linear(x) y = example_function(y) return y - 1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165703 Approved by: https://github.com/SherlockNoMad
This commit is contained in:
committed by
PyTorch MergeBot
parent
a16fd6b488
commit
75e2a9fae3
@ -922,6 +922,46 @@ class inner_f(torch.nn.Module):
|
||||
in custom_metadata
|
||||
)
|
||||
|
||||
def test_preserve_annotate_function(self):
|
||||
"""Test basic annotate_fn usage"""
|
||||
|
||||
@fx_traceback.annotate_fn({"pp_stage": 1})
|
||||
def example_function(x):
|
||||
return x * x
|
||||
|
||||
class SimpleLinear(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(3, 2)
|
||||
|
||||
def forward(self, x):
|
||||
with fx_traceback.annotate({"pp_stage": 0}):
|
||||
y = self.linear(x)
|
||||
y = example_function(y)
|
||||
return y - 1
|
||||
|
||||
inputs = (torch.randn(4, 3),)
|
||||
model = SimpleLinear()
|
||||
|
||||
for with_export in [True, False]:
|
||||
graph_module = graph_capture(model, inputs, with_export)
|
||||
custom_metadata = fx_traceback._get_custom_metadata(graph_module)
|
||||
self.assertExpectedInline(
|
||||
str(custom_metadata),
|
||||
"""\
|
||||
('call_function', 't', {'pp_stage': 0})
|
||||
('call_function', 'addmm', {'pp_stage': 0})
|
||||
('call_function', 'mul', {'pp_stage': 1})
|
||||
('call_function', 'mul_1', {'pp_stage': 1})
|
||||
('call_function', 'mul_2', {'pp_stage': 1})
|
||||
('call_function', 't_1', {'pp_stage': 0})
|
||||
('call_function', 'mm', {'pp_stage': 0})
|
||||
('call_function', 't_2', {'pp_stage': 0})
|
||||
('call_function', 'sum_1', {'pp_stage': 0})
|
||||
('call_function', 'view', {'pp_stage': 0})
|
||||
('call_function', 't_3', {'pp_stage': 0})""",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -18,6 +18,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [
|
||||
"annotate",
|
||||
"annotate_fn",
|
||||
"preserve_node_meta",
|
||||
"has_preserved_node_meta",
|
||||
"set_stack_trace",
|
||||
@ -266,9 +267,10 @@ def annotate(annotation_dict: dict):
|
||||
into the FX trace metadata.
|
||||
|
||||
Example:
|
||||
After exiting the context, custom annotations are removed.
|
||||
|
||||
>>> with annotate({"source": "custom_pass", "tag": 42}):
|
||||
... # compute here
|
||||
# After exiting the context, custom annotations are removed.
|
||||
... pass # Your computation here
|
||||
"""
|
||||
|
||||
global current_meta
|
||||
@ -291,6 +293,43 @@ def annotate(annotation_dict: dict):
|
||||
del current_meta["custom"]
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def annotate_fn(annotation_dict: dict):
|
||||
"""
|
||||
A decorator that wraps a function with the annotate context manager.
|
||||
Use this when you want to annotate an entire function instead of a specific code block.
|
||||
|
||||
Note:
|
||||
This API is **not backward compatible** and may evolve in future releases.
|
||||
|
||||
Note:
|
||||
This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
|
||||
to be used with PT2 family of tracers, e.g. torch.export and dynamo.
|
||||
|
||||
Args:
|
||||
annotation_dict (dict): A dictionary of custom key-value pairs to inject
|
||||
into the FX trace metadata for all operations in the function.
|
||||
|
||||
Example:
|
||||
All operations in my_function will have {"pp_stage": 1} in their metadata.
|
||||
|
||||
>>> @annotate_fn({"pp_stage": 1})
|
||||
... def my_function(x):
|
||||
... return x + 1
|
||||
"""
|
||||
from functools import wraps
|
||||
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with annotate(annotation_dict):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
@compatibility(is_backward_compatible=False)
|
||||
def set_grad_fn_seq_nr(seq_nr):
|
||||
global current_meta
|
||||
|
Reference in New Issue
Block a user