mirror of
				https://github.com/pytorch/pytorch.git
				synced 2025-10-26 08:34:52 +08:00 
			
		
		
		
	* Fix leaf modules in Transformer [ghstack-poisoned] * Fix tuple type annotations [ghstack-poisoned] * Generalize dict key check in `create-arg` (#51927) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51927 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26329655 Pulled By: jamesr66a fbshipit-source-id: a15e7d9564551521af12a8fde1c7524856f0cbc2
		
			
				
	
	
		
			224 lines
		
	
	
		
			9.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			224 lines
		
	
	
		
			9.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import dis
 | |
| import torch
 | |
| import inspect
 | |
| import operator
 | |
| 
 | |
| from .graph import magic_methods, reflectable_magic_methods, Graph
 | |
| from typing import Tuple, Dict, Optional, Iterable, Any, Iterator
 | |
| from .node import Target, Node, Argument, base_types, map_aggregate
 | |
| 
 | |
class TracerBase:
    """
    Common machinery shared by tracers: lowers Python values into ``Argument``
    values, records ``Node`` objects in ``self.graph`` and wraps them in
    ``Proxy`` objects.

    Subclasses are expected to provide the ``graph`` attribute.
    """
    graph: Graph

    def create_node(self, kind : str, target : Target,
                    args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None,
                    type_expr : Optional[Any] = None) -> Node:
        """
        Inserts a graph node given target, args, kwargs, and name.

        This method can be overridden to do extra checking, validation, or
        modification of values used in node creation. For example, one might
        want to disallow in-place operations from being recorded.

        The default implementation forwards every argument unchanged to
        ``self.graph.create_node`` and returns its result.
        """
        return self.graph.create_node(kind, target, args, kwargs, name, type_expr)

    def proxy(self, node: Node) -> 'Proxy':
        """Wrap ``node`` in a ``Proxy`` that is bound to this tracer."""
        return Proxy(node, self)

    def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any],
                     name: Optional[str] = None, type_expr : Optional[Any] = None):
        '''
        Create a Node from the given arguments, then return the Node
        wrapped in a Proxy object.

        ``args`` and ``kwargs`` are first lowered with ``create_arg`` so that
        any ``Proxy`` they contain is replaced by its underlying ``Node``.

        If kind = 'placeholder', then we're creating a Node that
        represents the parameter of a function. If we need to encode
        a default parameter, we use the ``args`` tuple. ``args`` is
        otherwise empty for ``placeholder`` Nodes.
        '''
        args_ = self.create_arg(args)
        kwargs_ = self.create_arg(kwargs)
        # Sanity checks that also narrow the ``Argument`` union for type checkers.
        assert isinstance(args_, tuple)
        assert isinstance(kwargs_, dict)
        return self.proxy(self.create_node(kind, target, args_, kwargs_, name, type_expr))

    def create_arg(self, a: Any) -> Argument:
        """
        A method that lowers the objects seen as arguments during symbolic evaluation
        into Argument types that can be stored in IR.

        Tuples (including named tuples), lists, dicts and slices are lowered
        recursively; a ``Proxy`` is replaced by its ``node``; values that are
        instances of ``base_types``, ``None`` and ``...`` are returned as-is.

        Can be overridden to support more trace-specific types.

        Raises:
            RuntimeError: if a dictionary key contains a ``Node`` after lowering.
            NotImplementedError: if ``a`` is of an unsupported type.
        """
        # aggregates
        if isinstance(a, tuple) and hasattr(a, '_fields'):
            # NamedTuple constructors don't seem to like getting a generator
            # expression as an argument to their constructor, so build this
            # intermediate tuple and unpack it into the NamedTuple constructor
            args = tuple(self.create_arg(elem) for elem in a)
            return type(a)(*args)  # type: ignore
        elif isinstance(a, (tuple, list)):
            return type(a)(self.create_arg(elem) for elem in a)
        elif isinstance(a, dict):
            r = {}
            for k, v in a.items():
                # Check for invalid dict keys. We do not want a Proxy to appear
                # anywhere within the key. Since keys can be collection types,
                # we iterate through the key with map_aggregate
                k = self.create_arg(k)

                def no_node(arg):
                    if isinstance(arg, Node):
                        # f-string so the offending (lowered) key is actually
                        # interpolated into the message.
                        raise RuntimeError("Keys for dictionaries used as an argument cannot contain a "
                                           f"Node. Got key: {k}")
                map_aggregate(k, no_node)

                r[k] = self.create_arg(v)
            return r
        elif isinstance(a, slice):
            return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        if isinstance(a, Proxy):
            # base case: we unwrap the Proxy object
            return a.node
        elif isinstance(a, base_types) or a is None or a is ...:
            return a

        raise NotImplementedError(f"argument of type: {type(a)}")

    def to_bool(self, obj: 'Proxy') -> bool:
        """Called when a proxy object is being converted to a boolean, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return a value.

        The default implementation always raises ``TraceError``.
        """
        raise TraceError('symbolically traced variables cannot be used as inputs to control flow')

    def iter(self, obj: 'Proxy') -> Iterator:
        """Called when a proxy object is being iterated over, such as
        when used in control flow.  Normally we don't know what to do because
        we don't know the value of the proxy, but a custom tracer can attach more
        information to the graph node using create_node and can choose to return an iterator.

        The default implementation always raises ``TraceError``.
        """
        raise TraceError('Proxy object cannot be iterated. '
                         'This can be attempted when used in a for loop or as a *args or **kwargs function argument.')

    def keys(self, obj: 'Proxy') -> Any:
        """Called when a proxy object has its keys() method called.
        This is what happens when ** is called on a proxy. This should return an
        iterator if ** is supposed to work in your custom tracer.

        The default implementation records a ``keys`` method call on ``obj``
        and returns the result of that call.
        """
        return Attribute(obj, 'keys')()
 | |
| 
 | |
| 
 | |
# used in Proxy object when just appending to the graph while not tracing.
class GraphAppendingTracer(TracerBase):
    """Tracer whose only state is an existing ``Graph`` that new nodes are added to."""

    def __init__(self, graph: Graph):
        super().__init__()
        # Nodes created through this tracer are recorded in the supplied graph.
        self.graph = graph
 | |
| 
 | |
class TraceError(ValueError):
    """Error raised when an operation cannot be carried out on a symbolically traced value."""
 | |
| 
 | |
| 
 | |
class Proxy:
    """
    ``Proxy`` objects are ``Node`` wrappers that flow through the
    program during symbolic tracing and record all the operations
    (``torch`` function calls, method calls, operators) that they touch
    into the growing FX Graph.

    If you're doing graph transforms, you can wrap your own ``Proxy``
    method around a raw ``Node`` so that you can use the overloaded
    operators to add additional things to a ``Graph``.
    """
    def __init__(self, node: 'Node', tracer: 'Optional[TracerBase]' = None):
        """
        Args:
            node: the ``Node`` this proxy stands for.
            tracer: the tracer used to record operations on this proxy. When
                omitted, a ``GraphAppendingTracer`` over ``node.graph`` is used.
        """
        if tracer is None:
            # This allows you to create a Proxy object around a raw Node
            tracer = GraphAppendingTracer(node.graph)
        self.tracer = tracer
        self.node = node

    def __repr__(self) -> str:
        return f'Proxy({self.node.name})'

    def __getattr__(self, k) -> 'Attribute':
        # note: not added to the graph yet, if this is a method call
        # we peephole optimize to the method invocation
        return Attribute(self, k)

    def __call__(self, *args, **kwargs) -> 'Proxy':
        """Record calling this proxy as a ``call_method`` node targeting ``__call__``."""
        return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs)

    def __iter__(self) -> Iterable['Proxy']:
        """
        Support tuple unpacking (``a, b = proxy``) and delegate any other
        iteration to ``self.tracer.iter``.

        The bytecode instruction the caller is currently executing is
        inspected; if it is ``UNPACK_SEQUENCE`` its argument gives the number
        of elements requested, and that many ``self[i]`` items are yielded.
        """
        frame = inspect.currentframe()
        assert frame is not None
        calling_frame = frame.f_back
        assert calling_frame is not None
        # Locate the caller's current instruction by byte offset rather than by
        # ``f_lasti // 2``: instructions are not uniformly two bytes wide once
        # inline cache entries exist (CPython 3.11+), and those entries are not
        # listed by ``dis.get_instructions``. On older interpreters both
        # lookups select the same instruction.
        lasti = calling_frame.f_lasti
        inst = next(
            (candidate for candidate in dis.get_instructions(calling_frame.f_code)
             if candidate.offset >= lasti),
            None,
        )
        if inst is not None and inst.opname == 'UNPACK_SEQUENCE':
            return (self[i] for i in range(inst.argval))  # type: ignore

        return self.tracer.iter(self)

    def __bool__(self) -> bool:
        return self.tracer.to_bool(self)

    def keys(self):
        return self.tracer.keys(self)

    def __len__(self):
        raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want "
                           "this call to be recorded, please call torch.fx.wrap('len') at "
                           "module scope")

    def __torch_function__(self, orig_method, types, args=None, kwargs=None):
        """
        Record a ``torch`` call that received this proxy as an argument.

        Tensor methods and properties become ``call_method`` nodes named after
        ``orig_method``; everything else becomes a ``call_function`` node whose
        target is ``orig_method`` itself.
        """
        args = args if args else ()
        kwargs = kwargs if kwargs else {}
        if torch.overrides.is_tensor_method_or_property(orig_method):
            return self.tracer.create_proxy('call_method', orig_method.__name__, args, kwargs)
        else:
            return self.tracer.create_proxy('call_function', orig_method, args, kwargs,
                                            name=self.tracer.graph._target_to_str(orig_method.__name__))
 | |
| 
 | |
class Attribute(Proxy):
    """
    Proxy for ``root.attr`` whose graph node is only created on demand.

    Calling the attribute records a method call on ``root`` instead, in
    which case no ``getattr`` node is created by the call itself.
    """

    def __init__(self, root: Proxy, attr: str):
        self.tracer = root.tracer
        self.root = root
        self.attr = attr
        # Filled in the first time ``node`` is read.
        self._node: Optional[Node] = None

    @property
    def node(self):
        # the node for attributes is added lazily, since most will just be method calls
        # which do not rely on the getattr call
        if self._node is None:
            fetched = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {})
            self._node = fetched.node
        return self._node

    def __call__(self, *args, **kwargs):
        # The receiver is passed as the first positional argument of the method call.
        call_args = (self.root,) + args
        return self.tracer.create_proxy('call_method', self.attr, call_args, kwargs)
 | |
| 
 | |
def scope(method):
    # A dedicated function scope gives every generated ``impl`` its own
    # ``method`` binding instead of sharing the loop variable.
    def impl(*args, **kwargs):
        # args[0] is the Proxy the operator was invoked on.
        recorder = args[0].tracer
        op = getattr(operator, method)
        return recorder.create_proxy('call_function', op, args, kwargs)
    impl.__name__ = method
    setattr(Proxy, f'__{method}__', impl)


# Install one ``__<name>__`` operator on Proxy for every entry of magic_methods.
for method in magic_methods:
    scope(method)
 | |
| 
 | |
def _define_reflectable(orig_method_name):
    """
    Install the reflected operator ``__r<orig_method_name>__`` on ``Proxy``.

    The installed method records a ``call_function`` node for
    ``operator.<orig_method_name>`` with the operands swapped, i.e. with the
    other operand first and the proxy second.
    """
    reflected_name = f'__r{orig_method_name}__'

    def impl(self, rhs):
        op = getattr(operator, orig_method_name)
        swapped_operands = (rhs, self)
        return self.tracer.create_proxy('call_function', op, swapped_operands, {})

    impl.__qualname__ = reflected_name
    impl.__name__ = reflected_name
    setattr(Proxy, reflected_name, impl)


# Register a reflected variant for every operator listed in reflectable_magic_methods.
for orig_method_name in reflectable_magic_methods:
    _define_reflectable(orig_method_name)
 |