mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[annotate] add annotate_fn function decorator (#165703)
Example usage:
```
@fx_traceback.annotate_fn({"pp_stage": 1})
def example_function(x):
return x * x
class SimpleLinear(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 2)
def forward(self, x):
with fx_traceback.annotate({"pp_stage": 0}):
y = self.linear(x)
y = example_function(y)
return y - 1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165703
Approved by: https://github.com/SherlockNoMad
This commit is contained in:
committed by
PyTorch MergeBot
parent
a16fd6b488
commit
75e2a9fae3
@ -922,6 +922,46 @@ class inner_f(torch.nn.Module):
|
|||||||
in custom_metadata
|
in custom_metadata
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_preserve_annotate_function(self):
|
||||||
|
"""Test basic annotate_fn usage"""
|
||||||
|
|
||||||
|
@fx_traceback.annotate_fn({"pp_stage": 1})
|
||||||
|
def example_function(x):
|
||||||
|
return x * x
|
||||||
|
|
||||||
|
class SimpleLinear(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
self.linear = nn.Linear(3, 2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
with fx_traceback.annotate({"pp_stage": 0}):
|
||||||
|
y = self.linear(x)
|
||||||
|
y = example_function(y)
|
||||||
|
return y - 1
|
||||||
|
|
||||||
|
inputs = (torch.randn(4, 3),)
|
||||||
|
model = SimpleLinear()
|
||||||
|
|
||||||
|
for with_export in [True, False]:
|
||||||
|
graph_module = graph_capture(model, inputs, with_export)
|
||||||
|
custom_metadata = fx_traceback._get_custom_metadata(graph_module)
|
||||||
|
self.assertExpectedInline(
|
||||||
|
str(custom_metadata),
|
||||||
|
"""\
|
||||||
|
('call_function', 't', {'pp_stage': 0})
|
||||||
|
('call_function', 'addmm', {'pp_stage': 0})
|
||||||
|
('call_function', 'mul', {'pp_stage': 1})
|
||||||
|
('call_function', 'mul_1', {'pp_stage': 1})
|
||||||
|
('call_function', 'mul_2', {'pp_stage': 1})
|
||||||
|
('call_function', 't_1', {'pp_stage': 0})
|
||||||
|
('call_function', 'mm', {'pp_stage': 0})
|
||||||
|
('call_function', 't_2', {'pp_stage': 0})
|
||||||
|
('call_function', 'sum_1', {'pp_stage': 0})
|
||||||
|
('call_function', 'view', {'pp_stage': 0})
|
||||||
|
('call_function', 't_3', {'pp_stage': 0})""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|||||||
@ -18,6 +18,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"annotate",
|
"annotate",
|
||||||
|
"annotate_fn",
|
||||||
"preserve_node_meta",
|
"preserve_node_meta",
|
||||||
"has_preserved_node_meta",
|
"has_preserved_node_meta",
|
||||||
"set_stack_trace",
|
"set_stack_trace",
|
||||||
@ -266,9 +267,10 @@ def annotate(annotation_dict: dict):
|
|||||||
into the FX trace metadata.
|
into the FX trace metadata.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
After exiting the context, custom annotations are removed.
|
||||||
|
|
||||||
>>> with annotate({"source": "custom_pass", "tag": 42}):
|
>>> with annotate({"source": "custom_pass", "tag": 42}):
|
||||||
... # compute here
|
... pass # Your computation here
|
||||||
# After exiting the context, custom annotations are removed.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
global current_meta
|
global current_meta
|
||||||
@ -291,6 +293,43 @@ def annotate(annotation_dict: dict):
|
|||||||
del current_meta["custom"]
|
del current_meta["custom"]
|
||||||
|
|
||||||
|
|
||||||
|
@compatibility(is_backward_compatible=False)
|
||||||
|
def annotate_fn(annotation_dict: dict):
|
||||||
|
"""
|
||||||
|
A decorator that wraps a function with the annotate context manager.
|
||||||
|
Use this when you want to annotate an entire function instead of a specific code block.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This API is **not backward compatible** and may evolve in future releases.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This API is not compatible with fx.symbolic_trace or jit.trace. It's intended
|
||||||
|
to be used with PT2 family of tracers, e.g. torch.export and dynamo.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
annotation_dict (dict): A dictionary of custom key-value pairs to inject
|
||||||
|
into the FX trace metadata for all operations in the function.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
All operations in my_function will have {"pp_stage": 1} in their metadata.
|
||||||
|
|
||||||
|
>>> @annotate_fn({"pp_stage": 1})
|
||||||
|
... def my_function(x):
|
||||||
|
... return x + 1
|
||||||
|
"""
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
def decorator(func):
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
with annotate(annotation_dict):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=False)
|
@compatibility(is_backward_compatible=False)
|
||||||
def set_grad_fn_seq_nr(seq_nr):
|
def set_grad_fn_seq_nr(seq_nr):
|
||||||
global current_meta
|
global current_meta
|
||||||
|
|||||||
Reference in New Issue
Block a user