mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-20 21:14:14 +08:00 
			
		
		
		
	[annotate] add annotate_fn function decorator (#165703)
Example usage:
```
        @fx_traceback.annotate_fn({"pp_stage": 1})
        def example_function(x):
            return x * x
        class SimpleLinear(nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = nn.Linear(3, 2)
            def forward(self, x):
                with fx_traceback.annotate({"pp_stage": 0}):
                    y = self.linear(x)
                y = example_function(y)
                return y - 1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165703
Approved by: https://github.com/SherlockNoMad
			
			
This commit is contained in:
		
				
					committed by
					
						 PyTorch MergeBot
						PyTorch MergeBot
					
				
			
			
				
	
			
			
			
						parent
						
							a16fd6b488
						
					
				
				
					commit
					75e2a9fae3
				
			| @ -922,6 +922,46 @@ class inner_f(torch.nn.Module): | |||||||
|             in custom_metadata |             in custom_metadata | ||||||
|         ) |         ) | ||||||
|  |  | ||||||
|  |     def test_preserve_annotate_function(self): | ||||||
|  |         """Test basic annotate_fn usage""" | ||||||
|  |  | ||||||
|  |         @fx_traceback.annotate_fn({"pp_stage": 1}) | ||||||
|  |         def example_function(x): | ||||||
|  |             return x * x | ||||||
|  |  | ||||||
|  |         class SimpleLinear(nn.Module): | ||||||
|  |             def __init__(self): | ||||||
|  |                 super().__init__() | ||||||
|  |                 self.linear = nn.Linear(3, 2) | ||||||
|  |  | ||||||
|  |             def forward(self, x): | ||||||
|  |                 with fx_traceback.annotate({"pp_stage": 0}): | ||||||
|  |                     y = self.linear(x) | ||||||
|  |                 y = example_function(y) | ||||||
|  |                 return y - 1 | ||||||
|  |  | ||||||
|  |         inputs = (torch.randn(4, 3),) | ||||||
|  |         model = SimpleLinear() | ||||||
|  |  | ||||||
|  |         for with_export in [True, False]: | ||||||
|  |             graph_module = graph_capture(model, inputs, with_export) | ||||||
|  |             custom_metadata = fx_traceback._get_custom_metadata(graph_module) | ||||||
|  |             self.assertExpectedInline( | ||||||
|  |                 str(custom_metadata), | ||||||
|  |                 """\ | ||||||
|  | ('call_function', 't', {'pp_stage': 0}) | ||||||
|  | ('call_function', 'addmm', {'pp_stage': 0}) | ||||||
|  | ('call_function', 'mul', {'pp_stage': 1}) | ||||||
|  | ('call_function', 'mul_1', {'pp_stage': 1}) | ||||||
|  | ('call_function', 'mul_2', {'pp_stage': 1}) | ||||||
|  | ('call_function', 't_1', {'pp_stage': 0}) | ||||||
|  | ('call_function', 'mm', {'pp_stage': 0}) | ||||||
|  | ('call_function', 't_2', {'pp_stage': 0}) | ||||||
|  | ('call_function', 'sum_1', {'pp_stage': 0}) | ||||||
|  | ('call_function', 'view', {'pp_stage': 0}) | ||||||
|  | ('call_function', 't_3', {'pp_stage': 0})""", | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     run_tests() |     run_tests() | ||||||
|  | |||||||
| @ -18,6 +18,7 @@ log = logging.getLogger(__name__) | |||||||
|  |  | ||||||
| __all__ = [ | __all__ = [ | ||||||
|     "annotate", |     "annotate", | ||||||
|  |     "annotate_fn", | ||||||
|     "preserve_node_meta", |     "preserve_node_meta", | ||||||
|     "has_preserved_node_meta", |     "has_preserved_node_meta", | ||||||
|     "set_stack_trace", |     "set_stack_trace", | ||||||
| @ -266,9 +267,10 @@ def annotate(annotation_dict: dict): | |||||||
|             into the FX trace metadata. |             into the FX trace metadata. | ||||||
|  |  | ||||||
|     Example: |     Example: | ||||||
|  |         After exiting the context, custom annotations are removed. | ||||||
|  |  | ||||||
|         >>> with annotate({"source": "custom_pass", "tag": 42}): |         >>> with annotate({"source": "custom_pass", "tag": 42}): | ||||||
|         ...     # compute here |         ...     pass  # Your computation here | ||||||
|         # After exiting the context, custom annotations are removed. |  | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     global current_meta |     global current_meta | ||||||
| @ -291,6 +293,43 @@ def annotate(annotation_dict: dict): | |||||||
|             del current_meta["custom"] |             del current_meta["custom"] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @compatibility(is_backward_compatible=False) | ||||||
|  | def annotate_fn(annotation_dict: dict): | ||||||
|  |     """ | ||||||
|  |     A decorator that wraps a function with the annotate context manager. | ||||||
|  |     Use this when you want to annotate an entire function instead of a specific code block. | ||||||
|  |  | ||||||
|  |     Note: | ||||||
|  |         This API is **not backward compatible** and may evolve in future releases. | ||||||
|  |  | ||||||
|  |     Note: | ||||||
|  |         This API is not compatible with fx.symbolic_trace or jit.trace. It's intended | ||||||
|  |         to be used with PT2 family of tracers, e.g. torch.export and dynamo. | ||||||
|  |  | ||||||
|  |     Args: | ||||||
|  |         annotation_dict (dict): A dictionary of custom key-value pairs to inject | ||||||
|  |             into the FX trace metadata for all operations in the function. | ||||||
|  |  | ||||||
|  |     Example: | ||||||
|  |         All operations in my_function will have {"pp_stage": 1} in their metadata. | ||||||
|  |  | ||||||
|  |         >>> @annotate_fn({"pp_stage": 1}) | ||||||
|  |         ... def my_function(x): | ||||||
|  |         ...     return x + 1 | ||||||
|  |     """ | ||||||
|  |     from functools import wraps | ||||||
|  |  | ||||||
|  |     def decorator(func): | ||||||
|  |         @wraps(func) | ||||||
|  |         def wrapper(*args, **kwargs): | ||||||
|  |             with annotate(annotation_dict): | ||||||
|  |                 return func(*args, **kwargs) | ||||||
|  |  | ||||||
|  |         return wrapper | ||||||
|  |  | ||||||
|  |     return decorator | ||||||
|  |  | ||||||
|  |  | ||||||
| @compatibility(is_backward_compatible=False) | @compatibility(is_backward_compatible=False) | ||||||
| def set_grad_fn_seq_nr(seq_nr): | def set_grad_fn_seq_nr(seq_nr): | ||||||
|     global current_meta |     global current_meta | ||||||
|  | |||||||
		Reference in New Issue
	
	Block a user